2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -1,6 +1,6 @@
 repos:
   - repo: https://github.com/astral-sh/ruff-pre-commit
-    rev: v0.4.5
+    rev: v0.11.0
     hooks:
       - id: ruff
         args: [--fix, --respect-gitignore, --config=pyproject.toml]
7 changes: 0 additions & 7 deletions Makefile

This file was deleted.

2 changes: 1 addition & 1 deletion README.md
@@ -2,7 +2,7 @@

## Introduction

-**`cogkit`** is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [**CogView**](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [**CogVideoX**](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as **text-to-image (T2I)**, **text-to-video (T2V)**, and **image-to-video (I2V)**. Users must comply with legal and ethical guidelines to ensure responsible implementation.
+CogKit is an open-source project that provides a user-friendly interface for researchers and developers to utilize ZhipuAI's [**CogView**](https://huggingface.co/collections/THUDM/cogview-67ac3f241eefad2af015669b) (image generation) and [**CogVideoX**](https://huggingface.co/collections/THUDM/cogvideo-66c08e62f1685a3ade464cce) (video generation) models. It streamlines multimodal tasks such as **text-to-image (T2I)**, **text-to-video (T2V)**, and **image-to-video (I2V)**. Users must comply with legal and ethical guidelines to ensure responsible implementation.

Visit our [**Docs**](https://thudm.github.io/CogKit) to start.

2 changes: 1 addition & 1 deletion docs/01-Intro.md
@@ -4,7 +4,7 @@ slug: /

# Introduction

-`cogkit` is a powerful framework for working with cognitive AI models, focusing on multi-modal generation and fine-tuning capabilities. It provides a unified interface for various AI tasks including text-to-image, text-to-video, and image-to-video generation.
+CogKit is a powerful framework for working with cognitive AI models, focusing on multi-modal generation and fine-tuning capabilities. It provides a unified interface for various AI tasks including text-to-image, text-to-video, and image-to-video generation.

## Supported Models

27 changes: 3 additions & 24 deletions docs/02-Installation.md
@@ -3,19 +3,13 @@

# Installation

-`cogkit` can be installed using pip. We recommend using a virtual environment to avoid conflicts with other packages.

## Requirements

- Python 3.10 or higher
-- OpenCV and PyTorch
+- PyTorch

## Installation Steps

-### OpenCV
-
-Please refer to the [opencv-python installation guide](https://github.com/opencv/opencv-python?tab=readme-ov-file#installation-and-usage) for instructions on installing OpenCV according to your system.

### PyTorch

Please refer to the [PyTorch installation guide](https://pytorch.org/get-started/locally/) for instructions on installing PyTorch according to your system.
@@ -25,13 +19,13 @@ Please refer to the [PyTorch installation guide](https://pytorch.org/get-started
1. Install `cogkit`:

```bash
-pip install cogkit@git+https://github.com/THUDM/cogkit.git
+pip install "cogkit@git+https://github.com/THUDM/cogkit.git"
```

2. Optional: for video tasks (e.g. text-to-video), install additional dependencies:

```bash
-pip install -e .[video]
+pip install "cogkit[video]@git+https://github.com/THUDM/cogkit.git"
```

### Verify installation
@@ -41,18 +35,3 @@ You can verify that cogkit is installed correctly by running:
```bash
cogkit --help
```

-and will get:
-
-```text
-Usage: cogkit [OPTIONS] COMMAND [ARGS]...
-
-Options:
-  -v, --verbose  Verbosity level (from 0 to 2)  [default: 0; 0<=x<=2]
-  --help         Show this message and exit.
-
-Commands:
-  finetune
-  inference  Generates a video based on the given prompt and saves it to...
-  launch
-```
65 changes: 60 additions & 5 deletions docs/03-Inference/01-CLI.md
@@ -4,7 +4,7 @@
<!-- TODO: check this doc -->
# Command-Line Interface

-`cogkit` provides a powerful command-line interface (CLI) that allows you to perform various tasks without writing Python code. This guide covers the available commands and their usage.
+CogKit provides a powerful command-line interface (CLI) that allows you to perform various tasks without writing Python code. This guide covers the available commands and their usage.

## Overview

@@ -45,15 +45,70 @@ See `cogkit inference --help` for more information.
<!-- TODO: add docs for launch server -->
## Launch Command

-The `launch` command will starts a API server:
+The `launch` command starts an API server for image and video generation. Before using this command, you need to install the API dependencies:

-<!-- FIXME: Add examples -->
```bash
-...
+pip install "cogkit[api]@git+https://github.com/THUDM/cogkit.git"
```

Please refer to [API](./02-API.md#api-server) for details on how to interact with the API server using client interfaces.
<!-- FIXME: correct url -->
Before starting the server, configure the paths of the models you want to serve; this determines which models the API server exposes.

To configure the model paths:

1. Create a `.env` file in your working directory
2. Refer to the [environment template]() and add the required environment variables to specify model paths. For example, to serve `CogView4-6B`, you must set `COGVIEW4_PATH` in your `.env` file:

```bash
# /your/workdir/.env

COGVIEW4_PATH="THUDM/CogView4-6B" # or local path
# other variables...
```

Then start the API server:

```bash
cogkit launch
```

:::tip
See `cogkit launch --help` for more information.
:::


### Client Interfaces

The server API is OpenAI-compatible, which means you can use it with any OpenAI client library. Here's an example using the OpenAI Python client:

```python
import base64
from io import BytesIO

from openai import OpenAI
from PIL import Image

client = OpenAI(
    api_key="foo",
    base_url="http://localhost:8000/v1",  # Your server URL
)

# Generate an image with cogview-4
response = client.images.generate(
    model="cogview-4",
    prompt="a beautiful sunset over mountains",
    n=1,
    size="1024x1024",
)
image_b64 = response.data[0].b64_json

# Decode the base64 string
image_data = base64.b64decode(image_b64)

# Create an image from the decoded data
image = Image.open(BytesIO(image_data))

# Save the image
image.save("output.png")
```
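
Because the endpoint is OpenAI-compatible, any plain HTTP client works as well. Below is a minimal sketch using `requests` (not a CogKit dependency; install it separately); the `/v1/images/generations` path and port 8000 are assumptions inferred from the base URL above:

```python
# Hedged sketch: calling the server with plain HTTP instead of the OpenAI client.
import base64

import requests  # assumed installed; not part of cogkit's dependencies

resp = requests.post(
    "http://localhost:8000/v1/images/generations",  # assumed default host/port
    json={
        "model": "cogview-4",
        "prompt": "a beautiful sunset over mountains",
        "n": 1,
        "size": "1024x1024",
    },
)
resp.raise_for_status()
image_b64 = resp.json()["data"][0]["b64_json"]
with open("output.png", "wb") as f:
    f.write(base64.b64decode(image_b64))
```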
16 changes: 6 additions & 10 deletions docs/03-Inference/02-API.md
@@ -3,7 +3,7 @@

# API

-`cogkit` provides a powerful inference API for generating images and videos using various AI models. This document covers both the Python API and API server.
+CogKit provides a powerful inference API for generating images and videos using various AI models. This document covers both the Python API and API server.

## Python API

@@ -18,12 +18,15 @@ image = generate_image(
    model_id_or_path="THUDM/CogView4-6B",
    lora_model_id_or_path=None,
    transformer_path=None,
+    height=1024,
+    width=1024,
)
image.save("sunset.png")

-# Text-to-Video generation
+# Text/Image-to-Video generation
video = generate_video(
    prompt="a cat playing with a ball",
+    image_file="path/to/image.png",  # Needed for Image-to-Video task
    model_id_or_path="THUDM/CogVideoX1.5-5B",
    lora_model_id_or_path=None,
    transformer_path=None,
@@ -32,13 +35,6 @@ video = generate_video(
)
video.save("cat_video.mp4")
```
-<!-- TODO: add examples for i2v -->

-<!-- FIXME: correct url -->
-See function signatures in [generation.py](...) for more details.

-## API Server

-<!-- FIXME: add docs for the API server -->

-<!-- TODO: add examples -->
+See function signatures in [generation.py]() for more details.
1 change: 1 addition & 0 deletions pyproject.toml
@@ -27,6 +27,7 @@ dependencies = [
torch = ["numpy", "torch", "torchvision"]
api = [
    "fastapi[standard]~=0.115.11",
+    "openai~=1.67",
    "pydantic-settings~=2.8",
    "python-dotenv~=1.0",
]
2 changes: 1 addition & 1 deletion src/cogkit/api/application.py
@@ -22,7 +22,7 @@ def get_application(settings: APISettings | None = None) -> FastAPI:

    @asynccontextmanager
    async def lifespan(_: FastAPI) -> AsyncIterator[RequestState]:
-        yield {"image_generation": ImageGenerationService(settings.cogview4_path)}
+        yield {"image_generation": ImageGenerationService(settings)}

    app = FastAPI(lifespan=lifespan)

2 changes: 1 addition & 1 deletion src/cogkit/api/models/images/generation_params.py
@@ -8,7 +8,7 @@

class ImageGenerationParams(RequestParams):
    prompt: str
-    model: Literal["cogview-4"] = "cogview-4"
+    model: str = "cogview-4"
    n: int = 1
    size: Literal[
        "1024x1024", "768x1344", "864x1152", "1344x768", "1152x864", "1440x720", "720x1440"
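
Loosening `model` from a `Literal` to a plain `str` shifts model validation out of the request schema and into the service (via `is_valid_model`, below), while `size` stays schema-validated. A quick sketch of the resulting split, assuming `RequestParams` is a standard pydantic model:

```python
# Hedged sketch: schema-level vs. service-level validation after this change.
from pydantic import ValidationError

from cogkit.api.models.images import ImageGenerationParams

# Any model name now passes the schema; the service rejects unknown ones later
# with a 404 instead of a generic 422 parse error.
ImageGenerationParams(prompt="a fox", model="not-a-real-model")

# Sizes are still constrained by the Literal and fail fast at parse time.
try:
    ImageGenerationParams(prompt="a fox", size="512x512")
except ValidationError as err:
    print(err)
```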
24 changes: 17 additions & 7 deletions src/cogkit/api/routers/images.py
@@ -2,11 +2,14 @@


import base64
+import io
import time
+from http import HTTPStatus
from typing import Annotated

import numpy as np
-from fastapi import APIRouter, Depends
+from fastapi import APIRouter, Depends, HTTPException
+from PIL import Image

from cogkit.api.dependencies import get_image_generation_service
from cogkit.api.models.images import ImageGenerationParams, ImageInResponse, ImagesResponse
@@ -16,18 +19,25 @@


def np_to_base64(image_array: np.ndarray) -> str:
-    byte_stream = image_array.tobytes()
-    base64_str = base64.b64encode(byte_stream).decode("utf-8")
-    return base64_str
+    image = Image.fromarray(image_array)
+    buffered = io.BytesIO()
+    image.save(buffered, format="PNG")
+    return base64.b64encode(buffered.getvalue()).decode("utf-8")


@router.post("/generations", response_model=ImagesResponse)
def generations(
    image_generation: Annotated[ImageGenerationService, Depends(get_image_generation_service)],
    params: ImageGenerationParams,
) -> ImagesResponse:
-    images_lst = image_generation.generate(
+    if not image_generation.is_valid_model(params.model):
+        raise HTTPException(
+            status_code=HTTPStatus.NOT_FOUND,
+            detail=f"The model `{params.model}` does not exist. Supported models: {image_generation.supported_models}",
+        )
+    # TODO: add exception handling
+    image_lst = image_generation.generate(
        model=params.model, prompt=params.prompt, size=params.size, num_images=params.n
    )
-    images_base64 = [ImageInResponse(b64_json=np_to_base64(image)) for image in images_lst]
-    return ImagesResponse(created=int(time.time()), data=images_base64)
+    image_b64_lst = [ImageInResponse(b64_json=np_to_base64(image)) for image in image_lst]
+    return ImagesResponse(created=int(time.time()), data=image_b64_lst)
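
For reviewers who want to exercise this route end to end, here is a minimal sketch with FastAPI's `TestClient`. The `/v1/images/generations` mount point is an assumption inferred from the OpenAI-compatible base URL in the docs, and entering the client context runs the lifespan, so the configured model is actually loaded:

```python
# Hedged sketch: driving the route above through the ASGI app.
# Assumes a .env with COGVIEW4_PATH set and the router mounted under /v1/images.
from fastapi.testclient import TestClient

from cogkit.api.application import get_application

with TestClient(get_application()) as client:  # context manager triggers lifespan
    resp = client.post(
        "/v1/images/generations",
        json={"model": "cogview-4", "prompt": "a red fox", "n": 1, "size": "1024x1024"},
    )
    assert resp.status_code == 200
    assert resp.json()["data"][0]["b64_json"]  # PNG bytes, base64-encoded
```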
51 changes: 43 additions & 8 deletions src/cogkit/api/services/image_generation.py
@@ -4,22 +4,45 @@
import numpy as np
from diffusers import CogView4Pipeline

+from cogkit.api.logging import get_logger
+from cogkit.api.settings import APISettings
+
+_logger = get_logger(__name__)
+

class ImageGenerationService(object):
-    def __init__(self, cogview4_path: str | None) -> None:
+    def __init__(self, settings: APISettings) -> None:
        self._models = {}
-        if cogview4_path is not None:
-            cogview4_pl = CogView4Pipeline.from_pretrained(cogview4_path)
+        if settings.cogview4_path is not None:
+            cogview4_pl = CogView4Pipeline.from_pretrained(settings.cogview4_path)
            cogview4_pl.enable_model_cpu_offload()
            cogview4_pl.vae.enable_slicing()
-            cogview4_pl.vae.enable_titling()
+            cogview4_pl.vae.enable_tiling()
            self._models["cogview-4"] = cogview4_pl

-    def generate(self, model: str, prompt: str, size: int, num_images: int) -> list[np.ndarray]:
+        ### Check if loaded models are supported
+        for model in self._models.keys():
+            if model not in settings._supported_models:
+                raise ValueError(
+                    f"Registered model {model} not in supported list: {settings._supported_models}"
+                )
+
+        ### Check if all supported models are loaded
+        for model in settings._supported_models:
+            if model not in self._models:
+                _logger.warning(f"Model {model} not loaded")
+
+    @property
+    def supported_models(self) -> list[str]:
+        return list(self._models.keys())
+
+    def generate(self, model: str, prompt: str, size: str, num_images: int) -> list[np.ndarray]:
        if model not in self._models:
-            raise ValueError(f"Model {model} not found")
+            raise ValueError(f"Model {model} not loaded")
        width, height = list(map(int, size.split("x")))
-        images_lst = self._models[model](
+
+        # shape of image_np: (n, h, w, c)
+        image_np = self._models[model](
            prompt=prompt,
            height=height,
            width=width,
@@ -28,4 +51,16 @@ def generate(self, model: str, prompt: str, size: int, num_images: int) -> list[
            num_images_per_prompt=num_images,
            output_type="np",
        ).images
-        return images_lst
+        assert image_np.ndim == 4, f"Expected 4D array, got {image_np.ndim}D array"
+
+        image_lst = self.postprocess(image_np)
+        return image_lst
+
+    def is_valid_model(self, model: str) -> bool:
+        return model in self._models
+
+    def postprocess(self, image_np: np.ndarray) -> list[np.ndarray]:
+        image_np = (image_np * 255).round().astype("uint8")
+        image_lst = np.split(image_np, image_np.shape[0], axis=0)
+        image_lst = [img.squeeze(0) for img in image_lst]
+        return image_lst
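
To make the postprocessing contract concrete: the pipeline returns a float array of shape `(n, h, w, c)` with values in `[0, 1]`, and `postprocess` converts it into `n` uint8 HWC frames ready for `PIL.Image.fromarray`. A toy sketch with random data standing in for pipeline output:

```python
# Hedged sketch of the uint8 conversion and batch split implemented above.
import numpy as np

batch = np.random.rand(2, 64, 64, 3)  # stand-in for `.images`, floats in [0, 1]
as_uint8 = (batch * 255).round().astype("uint8")
frames = [a.squeeze(0) for a in np.split(as_uint8, as_uint8.shape[0], axis=0)]

assert len(frames) == 2
assert frames[0].shape == (64, 64, 3) and frames[0].dtype == np.uint8
```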
5 changes: 3 additions & 2 deletions src/cogkit/api/settings.py
@@ -6,6 +6,7 @@

class APISettings(BaseSettings):
    model_config = SettingsConfigDict(
-        extra="ignore", validate_default=True, validate_assignment=True
+        extra="ignore", validate_default=True, validate_assignment=True, env_file=".env"
    )
-    cogview4_path: str = "THUDM/CogView4-6B"
+    _supported_models: tuple[str, ...] = ("cogview-4",)
+    cogview4_path: str | None = None
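
With `env_file=".env"` set, pydantic-settings now resolves `cogview4_path` from the `COGVIEW4_PATH` environment variable or a local `.env`, and the new `None` default means an unset variable simply skips loading CogView4. A quick sketch of that behavior, assuming pydantic-settings' default field-to-env-var mapping and no stray `.env` in the working directory:

```python
# Hedged sketch: how the relaxed default interacts with environment loading.
import os

from cogkit.api.settings import APISettings

os.environ["COGVIEW4_PATH"] = "THUDM/CogView4-6B"  # or a local checkpoint path
assert APISettings().cogview4_path == "THUDM/CogView4-6B"

del os.environ["COGVIEW4_PATH"]
assert APISettings().cogview4_path is None  # unset -> CogView4 just isn't served
```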