33
44import numpy as np
55import os
6+
67import torch
78
89from cogkit .api .logging import get_logger
@@ -26,12 +27,14 @@ def __init__(self, settings: APISettings) -> None:
2627 before_generation (cogview4_pl , settings .offload_type )
2728 self ._models ["cogview-4" ] = cogview4_pl
2829
30+ ### Check if loaded models are supported
2931 for model in self ._models .keys ():
3032 if model not in settings ._supported_models :
3133 raise ValueError (
3234 f"Registered model { model } not in supported list: { settings ._supported_models } "
3335 )
3436
37+ ### Check if all supported models are loaded
3538 for model in settings ._supported_models :
3639 if model not in self ._models :
3740 _logger .warning (f"Model { model } not loaded" )
@@ -54,6 +57,7 @@ def generate(
5457 raise ValueError (f"Model { model } not loaded" )
5558 width , height = list (map (int , size .split ("x" )))
5659
60+ # TODO: Refactor this to switch by LoRA endpoint API
5761 if lora_path is not None :
5862 adapter_name = os .path .basename (lora_path )
5963 _logger .info (f"Loaded LORA weights from { adapter_name } " )
@@ -64,12 +68,13 @@ def generate(
6468
6569 output = generate_image (
6670 prompt = prompt ,
71+ pipeline = self ._models [model ],
72+ num_images_per_prompt = num_images ,
73+ output_type = "np" ,
6774 height = height ,
6875 width = width ,
6976 num_inference_steps = num_inference_steps ,
7077 guidance_scale = guidance_scale ,
71- num_images_per_prompt = num_images ,
72- output_type = "np" ,
7378 )
7479
7580 image_lst = self .postprocess (output )
@@ -79,7 +84,6 @@ def is_valid_model(self, model: str) -> bool:
7984 return model in self ._models
8085
def postprocess(self, image_np: np.ndarray) -> list[np.ndarray]:
    """Split a batched image array into a list of per-image arrays.

    Args:
        image_np: Batched images; axis 0 is the batch dimension.

    Returns:
        One array per batch entry, with the batch axis removed
        (i.e. element ``i`` equals ``image_np[i]``).
    """
    batch_size = image_np.shape[0]
    chunks = np.split(image_np, batch_size, axis=0)
    return [chunk.squeeze(0) for chunk in chunks]
0 commit comments