Skip to content

Commit 6b798fe

Browse files
Revert "fix with main?"
This reverts commit 846e9dc, reversing changes made to e267385.
1 parent: 846e9dc · commit: 6b798fe

File tree

9 files changed

+12 additions, -216 deletions (lines changed)

9 files changed

+12 additions, -216 deletions (lines changed)

.env.template

Lines changed: 0 additions & 9 deletions
Original file line number | Diff line number | Diff line change
@@ -6,12 +6,3 @@ OFFLOAD_TYPE=cpu_model_offload
66
COGVIEW4_PATH=THUDM/CogView4-6B
77
# Optional, only needed when you don't want to use the default transformer in COGVIEW4_PATH
88
# COGVIEW4_TRANSFORMER_PATH=
9-
10-
### OPENAI API
11-
12-
OPENAI_API_KEY=
13-
OPENAI_BASE_URL=
14-
15-
### LORA PATH FOR GUI(test)
16-
17-
LORA_DIR=

src/cogkit/api/python/generation/util.py

Lines changed: 1 addition & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -6,13 +6,11 @@
66
CogVideoXDPMScheduler,
77
CogVideoXImageToVideoPipeline,
88
CogVideoXPipeline,
9-
CogView4ControlPipeline,
109
CogView4Pipeline,
1110
)
1211

1312
TVideoPipeline = CogVideoXPipeline | CogVideoXImageToVideoPipeline
1413
TPipeline = CogView4Pipeline | TVideoPipeline
15-
CogviewPipline = CogView4Pipeline | CogView4ControlPipeline
1614

1715

1816
def _is_cogvideox1_0(pipeline: TVideoPipeline) -> bool:
@@ -105,7 +103,7 @@ def guess_resolution(
105103
height: int | None = None,
106104
width: int | None = None,
107105
) -> tuple[int, int]:
108-
if isinstance(pipeline, CogviewPipline):
106+
if isinstance(pipeline, CogView4Pipeline):
109107
return _guess_cogview_resolution(pipeline, height=height, width=width)
110108
if isinstance(pipeline, TVideoPipeline):
111109
return _guess_cogvideox_resolution(pipeline, height=height, width=width)

src/cogkit/api/services/image_generation.py

Lines changed: 7 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33

44
import numpy as np
55
import os
6+
67
import torch
78

89
from cogkit.api.logging import get_logger
@@ -26,12 +27,14 @@ def __init__(self, settings: APISettings) -> None:
2627
before_generation(cogview4_pl, settings.offload_type)
2728
self._models["cogview-4"] = cogview4_pl
2829

30+
### Check if loaded models are supported
2931
for model in self._models.keys():
3032
if model not in settings._supported_models:
3133
raise ValueError(
3234
f"Registered model {model} not in supported list: {settings._supported_models}"
3335
)
3436

37+
### Check if all supported models are loaded
3538
for model in settings._supported_models:
3639
if model not in self._models:
3740
_logger.warning(f"Model {model} not loaded")
@@ -54,6 +57,7 @@ def generate(
5457
raise ValueError(f"Model {model} not loaded")
5558
width, height = list(map(int, size.split("x")))
5659

60+
# TODO: Refactor this to switch by LoRA endpoint API
5761
if lora_path is not None:
5862
adapter_name = os.path.basename(lora_path)
5963
_logger.info(f"Loaded LORA weights from {adapter_name}")
@@ -64,12 +68,13 @@ def generate(
6468

6569
output = generate_image(
6670
prompt=prompt,
71+
pipeline=self._models[model],
72+
num_images_per_prompt=num_images,
73+
output_type="np",
6774
height=height,
6875
width=width,
6976
num_inference_steps=num_inference_steps,
7077
guidance_scale=guidance_scale,
71-
num_images_per_prompt=num_images,
72-
output_type="np",
7378
)
7479

7580
image_lst = self.postprocess(output)
@@ -79,7 +84,6 @@ def is_valid_model(self, model: str) -> bool:
7984
return model in self._models
8085

8186
def postprocess(self, image_np: np.ndarray) -> list[np.ndarray]:
82-
image_np = (image_np * 255).round().astype("uint8")
8387
image_lst = np.split(image_np, image_np.shape[0], axis=0)
8488
image_lst = [img.squeeze(0) for img in image_lst]
8589
return image_lst

src/cogkit/api/settings.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -19,5 +19,3 @@ class APISettings(BaseSettings):
1919
# cogview-4 related settings
2020
cogview4_path: str | None = None
2121
cogview4_transformer_path: str | None = None
22-
openai_api_key: str | None = None
23-
lora_dir: str | None = None

src/cogkit/cli/inference.py

Lines changed: 2 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -31,18 +31,10 @@
3131
type=click.Path(dir_okay=False, writable=True),
3232
help="the path to save the generated image/video. If not provided, the generated image/video will be saved to 'output.png/mp4'.",
3333
)
34-
@click.option(
35-
"--task",
36-
type=click.Choice(
37-
choices=[mode.value for mode in GenerationMode],
38-
case_sensitive=False,
39-
),
40-
help="the generation task",
41-
)
4234
@click.option(
4335
"--image_file",
4436
type=click.Path(exists=True, file_okay=True, dir_okay=False, readable=True),
45-
help="the image to guide the video generation (for i2v or ct2i generation task)",
37+
help="the image to guide the image/video generation (for i2i/i2v generation task)",
4638
)
4739
@click.option(
4840
"--dtype",
@@ -158,10 +150,7 @@ def inference(
158150
_logger.info("Saving the generated video to path '%s'.", os.fspath(output_file))
159151
export_to_video(output[0], output_file, fps=fps)
160152

161-
elif task in (
162-
GenerationMode.TextToImage,
163-
GenerationMode.CtrlTextToImage,
164-
):
153+
elif task in (GenerationMode.TextToImage,):
165154
batched_images = generate_image(
166155
prompt=prompt,
167156
pipeline=pipeline,

src/cogkit/types/generation_mode.py

Lines changed: 0 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -8,4 +8,3 @@ class GenerationMode(enum.Enum):
88
TextToVideo = "t2v"
99
ImageToVideo = "i2v"
1010
TextToImage = "t2i"
11-
CtrlTextToImage = "ct2i"

src/cogkit/utils/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,6 @@
66
from cogkit.utils.lora import load_lora_checkpoint, unload_lora_checkpoint
77
from cogkit.utils.misc import guess_generation_mode
88
from cogkit.utils.path import mkdir, resolve_path
9-
from cogkit.utils.prompt import convert_prompt
109
from cogkit.utils.random import rand_generator
1110
from cogkit.utils.load import load_pipeline
1211

@@ -20,5 +19,4 @@
2019
"resolve_path",
2120
"rand_generator",
2221
"load_pipeline",
23-
"convert_prompt",
2422
]

src/cogkit/utils/misc.py

Lines changed: 2 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -3,7 +3,7 @@
33

44
from diffusers import DiffusionPipeline
55
from PIL import Image
6-
from pathlib import Path
6+
77
from cogkit.logging import get_logger
88
from cogkit.types import GenerationMode
99

@@ -14,7 +14,6 @@
1414
"CogView4Pipeline",
1515
"CogVideoXPipeline",
1616
"CogVideoXImageToVideoPipeline",
17-
"CogView4ControlPipeline",
1817
)
1918

2019

@@ -37,23 +36,6 @@ def _check_text_to_image_params(
3736
)
3837

3938

40-
def _check_control_text_to_image_params(
41-
pl_cls_name: str,
42-
generation_mode: GenerationMode | None,
43-
image: str | Path | None,
44-
) -> None:
45-
if generation_mode is not None and generation_mode != GenerationMode.CtrlTextToImage:
46-
_logger.warning(
47-
"The pipeline `%s` does not support `%s` task. Will try the `%s` task.",
48-
pl_cls_name,
49-
generation_mode.value,
50-
GenerationMode.CtrlTextToImage,
51-
)
52-
if image is not None:
53-
err_msg = f"Image input is required in the image2video pipeline. Please provide a regular image file (image_file = {image})."
54-
raise ValueError(err_msg)
55-
56-
5739
def _check_image_to_video_params(
5840
pl_cls_name: str,
5941
generation_mode: GenerationMode | None,
@@ -84,7 +66,7 @@ def guess_generation_mode(
8466
if generation_mode is not None:
8567
generation_mode = GenerationMode(generation_mode)
8668

87-
if pl_cls_name == "CogView4Pipeline":
69+
if pl_cls_name.startswith("CogView"):
8870
# TextToImage
8971
_check_text_to_image_params(pl_cls_name, generation_mode, image)
9072
return GenerationMode.TextToImage
@@ -93,11 +75,6 @@ def guess_generation_mode(
9375
_check_image_to_video_params(pl_cls_name, generation_mode, image)
9476
return GenerationMode.ImageToVideo
9577

96-
if pl_cls_name == "CogView4ControlPipeline":
97-
# Control TextToImage
98-
_check_control_text_to_image_params(pl_cls_name, generation_mode, image)
99-
return GenerationMode.CtrlTextToImage
100-
10178
if image is not None:
10279
_logger.warning(
10380
"Pipeline `%s` does not support image input. Will ignore the image file.",

src/cogkit/utils/prompt.py

Lines changed: 0 additions & 158 deletions
This file was deleted.

0 commit comments

Comments
 (0)