7 changes: 6 additions & 1 deletion .env.template
@@ -1 +1,6 @@
-COGVIEW4_PATH=THUDM/CogView4-6B
+COGVIEW4_PATH=/share/official_pretrains/hf_home/CogView4-6B
+DTYPE=bfloat16
+OFFLOAD_TYPE=no_offload
+OPENAI_API_KEY=
+OPENAI_BASE_URL=
+LORA_DIR=
28 changes: 28 additions & 0 deletions .github/PULL_REQUEST_TEMPLATE.md
@@ -0,0 +1,28 @@
# Contribution Guide

We welcome your contributions to this repository. To keep the code style consistent and the code quality high, we have prepared the following contribution guidelines.

## What We Accept

+ This PR fixes a typo or improves the documentation (if this is the case, you may skip the other checks).
+ This PR fixes a specific issue — please reference the issue number in the PR description. Make sure your code strictly follows the coding standards below.
+ This PR introduces a new feature — please clearly explain the necessity and implementation of the feature. Make sure your code strictly follows the coding standards below.

## Code Style Guide

Good code style is an art. We have prepared a `pyproject.toml` and a `pre-commit` hook to enforce consistent code formatting across the project. You can clean up your code following the steps below:

1. Install the required dependencies:
```shell
pip install ruff pre-commit
```
2. Then, run the following command:
```shell
pre-commit run --all-files
```
If your code complies with the standards, you should not see any errors.

## Naming Conventions

- Please use **English** for naming; do not use Pinyin or other languages. All comments should also be in English.
- Follow **PEP8** naming conventions strictly, and use underscores to separate words. Avoid meaningless names such as `a`, `b`, `c`.
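
For illustration, a small hypothetical example of the expected naming style:

```python
# Avoid: opaque single letters or Pinyin
# def jisuan(a, b): ...

# Prefer: descriptive English snake_case names, per PEP8
def count_failed_requests(request_logs: list[dict]) -> int:
    """Count log entries whose status code indicates a failure."""
    return sum(1 for entry in request_logs if entry["status"] >= 400)
```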
1 change: 1 addition & 0 deletions .gitignore
@@ -263,3 +263,4 @@ tmp/

webdoc/
**/wandb/
+test*
6 changes: 4 additions & 2 deletions docs/03-Inference/02-API.md
@@ -18,10 +18,11 @@ image = generate_image(
    model_id_or_path="THUDM/CogView4-6B",
    lora_model_id_or_path=None,
    transformer_path=None,
+   output_file="sunset.png",  # Images will be saved here.
    height=1024,
    width=1024,
)
-image.save("sunset.png")


# Text/Image-to-Video generation
video = generate_video(
@@ -30,10 +31,11 @@ video = generate_video(
    model_id_or_path="THUDM/CogVideoX1.5-5B",
    lora_model_id_or_path=None,
    transformer_path=None,
+   output_file="cat.mp4",  # Videos will be saved here.
    num_frames=81,
    fps=16,
)
-video.save("cat_video.mp4")

```

See the function signatures for more details.
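
As a rough sketch of how the LoRA support added in this PR threads through the same call (the import location and adapter path below are illustrative, and the argument list otherwise mirrors the example above):

```python
from cogkit import generate_image  # import location assumed from the example above

# Hypothetical LoRA adapter directory produced by a CogView4 fine-tuning run.
image = generate_image(
    prompt="A serene sunset over the mountains",
    model_id_or_path="THUDM/CogView4-6B",
    lora_model_id_or_path="checkpoints/cogview4-lora",  # illustrative path
    transformer_path=None,
    output_file="sunset_lora.png",  # Images will be saved here.
    height=1024,
    width=1024,
)
```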
2 changes: 2 additions & 0 deletions pyproject.toml
@@ -27,6 +27,8 @@ dependencies = [
torch = ["numpy", "torch", "torchvision"]
api = [
    "fastapi[standard]~=0.115.11",
+   "fastapi_cli~=0.0.7",
+   "pydantic_settings~=2.8.1",
    "openai~=1.67",
    "pydantic-settings~=2.8",
    "python-dotenv~=1.0",
3 changes: 3 additions & 0 deletions src/cogkit/api/models/images/generation_params.py
@@ -13,6 +13,9 @@ class ImageGenerationParams(RequestParams):
    size: Literal[
        "1024x1024", "768x1344", "864x1152", "1344x768", "1152x864", "1440x720", "720x1440"
    ] = "1024x1024"
+   num_inference_steps: int = 50
+   guidance_scale: float = 3.5
+   lora_path: str | None = None
    user: str | None = None
    # ! unsupported parameters
    # quality: Literal["standard", "hd"] = "standard"
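
To see how the new fields behave, a minimal sketch (assuming `RequestParams` supplies the OpenAI-style `model`, `prompt`, and `n` fields that the router below reads off the same object):

```python
from cogkit.api.models.images.generation_params import ImageGenerationParams

params = ImageGenerationParams(
    model="cogview-4",
    prompt="a lighthouse at dawn",
    size="768x1344",
    num_inference_steps=30,  # trade quality for speed on previews
    guidance_scale=4.0,
)
# The service splits `size` as "<width>x<height>":
width, height = map(int, params.size.split("x"))  # -> (768, 1344)
```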
9 changes: 7 additions & 2 deletions src/cogkit/api/routers/images.py
@@ -35,9 +35,14 @@ def generations(
            status_code=HTTPStatus.NOT_FOUND,
            detail=f"The model `{params.model}` does not exist. Supported models: {image_generation.supported_models}",
        )
    # TODO: add exception handling
    image_lst = image_generation.generate(
-       model=params.model, prompt=params.prompt, size=params.size, num_images=params.n
+       model=params.model,
+       prompt=params.prompt,
+       size=params.size,
+       num_images=params.n,
+       num_inference_steps=params.num_inference_steps,
+       guidance_scale=params.guidance_scale,
+       lora_path=params.lora_path,
    )
    image_b64_lst = [ImageInResponse(b64_json=np_to_base64(image)) for image in image_lst]
    return ImagesResponse(created=int(time.time()), data=image_b64_lst)
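
Because `num_inference_steps`, `guidance_scale`, and `lora_path` are not part of the standard OpenAI images schema, a client would pass them through `extra_body`. A sketch, assuming the server is mounted at an OpenAI-compatible base URL (the URL and key below are illustrative):

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="not-needed")

response = client.images.generate(
    model="cogview-4",
    prompt="an isometric city block in pastel colors",
    size="1024x1024",
    n=1,
    extra_body={  # fields added in this PR
        "num_inference_steps": 40,
        "guidance_scale": 3.5,
    },
)
b64_png = response.data[0].b64_json  # base64-encoded image bytes
```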
36 changes: 30 additions & 6 deletions src/cogkit/api/services/image_generation.py
@@ -2,6 +2,9 @@


import numpy as np
+import os

+import torch
from diffusers import CogView4Pipeline

from cogkit.api.logging import get_logger
@@ -14,8 +17,14 @@ class ImageGenerationService(object):
    def __init__(self, settings: APISettings) -> None:
        self._models = {}
        if settings.cogview4_path is not None:
-           cogview4_pl = CogView4Pipeline.from_pretrained(settings.cogview4_path)
-           cogview4_pl.enable_model_cpu_offload()
+           cogview4_pl = CogView4Pipeline.from_pretrained(
+               settings.cogview4_path,
+               torch_dtype=torch.bfloat16 if settings.dtype == "bfloat16" else torch.float32,
+           )
+           if settings.offload_type == "cpu_model_offload":
+               cogview4_pl.enable_model_cpu_offload()
+           else:
+               cogview4_pl.to("cuda")
            cogview4_pl.vae.enable_slicing()
            cogview4_pl.vae.enable_tiling()
            self._models["cogview-4"] = cogview4_pl
@@ -36,18 +45,33 @@ def __init__(self, settings: APISettings) -> None:
    def supported_models(self) -> list[str]:
        return list(self._models.keys())

-   def generate(self, model: str, prompt: str, size: str, num_images: int) -> list[np.ndarray]:
+   def generate(
+       self,
+       model: str,
+       prompt: str,
+       size: str,
+       num_images: int,
+       num_inference_steps: int = 50,
+       guidance_scale: float = 3.5,
+       lora_path: str | None = None,
+   ) -> list[np.ndarray]:
        if model not in self._models:
            raise ValueError(f"Model {model} not loaded")
        width, height = list(map(int, size.split("x")))
+       if lora_path is not None:
+           adapter_name = os.path.basename(lora_path)
+           self._models[model].load_lora_weights(lora_path)
+           print(f"Loaded LoRA weights from {lora_path} (adapter: {adapter_name})")
+       else:
+           print("Unloading LoRA weights")
+           self._models[model].unload_lora_weights()

        # shape of image_np: (n, h, w, c)
        image_np = self._models[model](
            prompt=prompt,
            height=height,
            width=width,
-           num_inference_steps=50,
-           guidance_scale=3.5,
+           num_inference_steps=num_inference_steps,
+           guidance_scale=guidance_scale,
            num_images_per_prompt=num_images,
            output_type="np",
        ).images
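
One consequence of the branch above is that every request reloads or unloads LoRA weights, even when consecutive requests use the same adapter. A possible refinement, not part of this PR (the `_current_lora` attribute is hypothetical):

```python
def _ensure_lora(self, model: str, lora_path: str | None) -> None:
    """Switch LoRA weights only when the requested adapter actually changes."""
    if lora_path == getattr(self, "_current_lora", None):
        return  # already in the requested state; skip the reload
    if lora_path is not None:
        self._models[model].load_lora_weights(lora_path)
    else:
        self._models[model].unload_lora_weights()
    self._current_lora = lora_path
```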
4 changes: 4 additions & 0 deletions src/cogkit/api/settings.py
@@ -2,6 +2,7 @@


from pydantic_settings import BaseSettings, SettingsConfigDict
+from typing import Literal


class APISettings(BaseSettings):
@@ -10,3 +11,6 @@ class APISettings(BaseSettings):
    )
    _supported_models: tuple[str, ...] = ("cogview-4",)
    cogview4_path: str | None = None
+   dtype: Literal["bfloat16", "float32"] = "bfloat16"
+   offload_type: Literal["cpu_model_offload", "no_offload"] = "no_offload"
+   openai_api_key: str | None = None
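
`pydantic-settings` matches these fields case-insensitively against environment variables and the `.env` file named in `SettingsConfigDict`, which is how the `.env.template` entries at the top of this diff reach the service. A minimal sketch of the pattern (the class name is illustrative):

```python
from typing import Literal

from pydantic_settings import BaseSettings, SettingsConfigDict


class DemoSettings(BaseSettings):
    model_config = SettingsConfigDict(env_file=".env", env_file_encoding="utf-8")

    cogview4_path: str | None = None
    dtype: Literal["bfloat16", "float32"] = "bfloat16"
    offload_type: Literal["cpu_model_offload", "no_offload"] = "no_offload"


# With DTYPE=bfloat16 and OFFLOAD_TYPE=no_offload in .env, both fields are
# populated from the file; unset variables fall back to the defaults above.
settings = DemoSettings()
```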
1 change: 1 addition & 0 deletions src/cogkit/cli/inference.py
@@ -133,6 +133,7 @@ def inference(
        output_file or "output.png",
        dtype=dtype,
        transformer_path=transformer_path,
+       lora_model_id_or_path=lora_model_id_or_path,
        height=height,
        width=width,
        seed=seed,
17 changes: 4 additions & 13 deletions src/cogkit/datasets/utils.py
@@ -216,20 +216,13 @@ def get_prompt_embedding(
def get_image_embedding(
    encode_fn: Callable, image: Image.Image, cache_dir: Path, logger: logging.Logger
) -> torch.Tensor:
-   """Get encoded image from cache or create new one if not exists.
-
-   Args:
-       encode_fn: Function to project image to embedding.
-       image: Image to be embedded
-       cache_dir: Base directory for caching embeddings
-       logger: Logger instance for logging messages
-
-   Returns:
-       torch.Tensor: Encoded image with shape [C, H, W]
-   """
    encoded_images_dir = cache_dir / "encoded_images"
    encoded_images_dir.mkdir(parents=True, exist_ok=True)

+   if not hasattr(image, "filename"):
+       logger.warning("Image object does not have filename attribute, skipping caching.")
+       return encode_fn(image).to("cpu")
+
    filename = Path(image.filename).stem
    filename_hash = str(hashlib.sha256(filename.encode()).hexdigest())
    encoded_image_path = encoded_images_dir / (filename_hash + ".safetensors")
@@ -243,8 +236,6 @@ def get_image_embedding(
    else:
        encoded_image = encode_fn(image)
        encoded_image = encoded_image.to("cpu")
-
-       # shape of encoded_image: [C, H, W]
        save_file({"encoded_image": encoded_image}, encoded_image_path)
        logger.info(
            f"Saved encoded image to {encoded_image_path}",
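
The cache key above is the SHA-256 of the image's filename stem, so re-encoding the same file is skipped on later runs. Roughly (values illustrative):

```python
import hashlib
from pathlib import Path

# The same scheme as get_image_embedding above, on a made-up filename.
stem = Path("dog_001.png").stem                   # "dog_001"
key = hashlib.sha256(stem.encode()).hexdigest()   # stable, filesystem-safe digest
cache_path = Path("cache/encoded_images") / f"{key}.safetensors"
```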
17 changes: 12 additions & 5 deletions src/cogkit/generation/image.py
@@ -9,7 +9,12 @@

from cogkit.generation.util import before_generation, guess_resolution
from cogkit.logging import get_logger
-from cogkit.utils import mkdir, rand_generator, resolve_path
+from cogkit.utils import (
+    load_lora_checkpoint,
+    mkdir,
+    rand_generator,
+    resolve_path,
+)

_logger = get_logger(__name__)

@@ -18,23 +23,24 @@ def generate_image(
    prompt: str,
    model_id_or_path: str,
    output_file: str | Path,
-   # * params for model loading
    dtype: torch.dtype = torch.bfloat16,
    transformer_path: str | None = None,
-   # * params for generated images
+   lora_model_id_or_path: str | None = None,
+   lora_rank: int = 128,
    height: int | None = None,
    width: int | None = None,
-   # * params for the generation process
    num_inference_steps: int = 50,
    guidance_scale: float = 3.5,
    seed: int | None = 42,
):
    pipeline = DiffusionPipeline.from_pretrained(model_id_or_path, torch_dtype=dtype)

    if transformer_path is not None:
        pipeline.transformer.save_config(transformer_path)
        pipeline.transformer = pipeline.transformer.from_pretrained(transformer_path)

+   if lora_model_id_or_path is not None:
+       load_lora_checkpoint(pipeline, lora_model_id_or_path, lora_rank)
+
    height, width = guess_resolution(pipeline, height, width)

    _logger.info(f"Generation config: height {height}, width {width}.")
@@ -55,3 +61,4 @@ def generate_image(
    mkdir(output_file.parent)
    _logger.info("Saving the generated image to path '%s'.", os.fspath(output_file))
    batch_image[0].save(output_file)
+   return batch_image
6 changes: 2 additions & 4 deletions src/cogkit/generation/util.py
@@ -141,10 +141,8 @@ def before_generation(pipeline: TPipeline) -> None:
    # * Moves the model to GPU, or enables CPU offload.
    # CPU offload saves GPU memory; with multiple GPUs or enough GPU memory
    # (such as an H100), moving the whole pipeline to CUDA is faster at inference.
-   # pipe.to("cuda")
-
-   # pipeline.to("cuda")
-   pipeline.enable_model_cpu_offload()
+   pipeline.to("cuda")
+   # pipeline.enable_model_cpu_offload()
    # pipe.enable_sequential_cpu_offload()
    if hasattr(pipeline, "vae"):
        pipeline.vae.enable_slicing()
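
The switch from CPU offload to `pipeline.to("cuda")` trades memory for speed: offload keeps only the active submodule on the GPU, while `.to("cuda")` keeps the whole pipeline resident. A defensive variant, not in this PR, might pick between them at runtime; the VRAM threshold below is an assumption:

```python
import torch


def place_pipeline(pipeline, min_vram_gb: float = 24.0) -> None:
    """Keep the full pipeline on the GPU only when there is plausibly room."""
    if not torch.cuda.is_available():
        raise RuntimeError("Both placement strategies assume a CUDA device.")
    total_gb = torch.cuda.get_device_properties(0).total_memory / 1024**3
    if total_gb >= min_vram_gb:
        pipeline.to("cuda")  # fastest: the whole pipeline stays resident
    else:
        pipeline.enable_model_cpu_offload()  # slower, but fits smaller GPUs
```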
6 changes: 1 addition & 5 deletions src/cogkit/generation/video.py
@@ -41,19 +41,15 @@ def generate_video(
    model_id_or_path: str,
    output_file: str | Path,
    image_file: str | Path | None = None,
-   # FIXME: whether to support v2v pipeline
    video_file: str | Path | None = None,
-   # * params for model loading
    dtype: torch.dtype = torch.bfloat16,
    transformer_path: str | None = None,
    lora_model_id_or_path: str | None = None,
    lora_rank: int = 128,
-   # * params for generated videos
    height: int | None = None,
    width: int | None = None,
    num_frames: int | None = None,
    fps: int | None = None,
-   # * params for the generation process
    num_inference_steps: int = 50,
    guidance_scale: float = 6.0,
    seed: int | None = 42,
@@ -64,7 +60,7 @@
        pipeline.transformer.save_config(transformer_path)
        pipeline.transformer = pipeline.transformer.from_pretrained(transformer_path)
    if lora_model_id_or_path is not None:
-       load_lora_checkpoint(lora_model_id_or_path, pipeline, lora_rank)
+       load_lora_checkpoint(pipeline, lora_model_id_or_path, lora_rank)

    height, width = guess_resolution(pipeline, height, width)
    num_frames, fps = guess_frames(pipeline, num_frames)
7 changes: 1 addition & 6 deletions src/cogkit/utils/lora.py
@@ -9,10 +9,5 @@ def load_lora_checkpoint(
    lora_model_id_or_path: str,
    lora_rank: int,
) -> None:
-   pipeline.load_lora_weights(
-       lora_model_id_or_path,
-       # TODO: ensures the name is correct
-       weight_name="pytorch_lora_weights.safetensors",
-       adapter_name="test_1",
-   )
+   pipeline.load_lora_weights(lora_model_id_or_path)
    pipeline.fuse_lora(components=["transformer"], lora_scale=1 / lora_rank)
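
After this change, `load_lora_weights` infers the weight file and adapter name itself, and the adapter is fused into the transformer at a scale of `1 / lora_rank`. A usage sketch (the checkpoint path is illustrative, and `lora_rank` must match the rank the adapter was trained with):

```python
from diffusers import CogView4Pipeline

from cogkit.utils import load_lora_checkpoint

pipe = CogView4Pipeline.from_pretrained("THUDM/CogView4-6B")
# 128 mirrors the default rank used by cogkit.generation above.
load_lora_checkpoint(pipe, lora_model_id_or_path="checkpoints/my-lora", lora_rank=128)
```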