From 6339afe245f4cd649d653f346fed62b30c25b8b2 Mon Sep 17 00:00:00 2001 From: zhoujiayi <2739647045@qq.com> Date: Fri, 6 Mar 2026 17:17:26 +0800 Subject: [PATCH] =?UTF-8?q?=E6=A8=A1=E5=9E=8B=E7=AB=AF=E7=82=B9=E9=85=8D?= =?UTF-8?q?=E7=BD=AE=E5=8C=96?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .env.example | 44 ++++++++++++++ src/apps/comic_gen/api.py | 4 +- src/apps/comic_gen/assets.py | 3 +- src/apps/comic_gen/llm.py | 20 ++++--- src/apps/comic_gen/models.py | 9 +-- src/apps/comic_gen/pipeline.py | 11 ++-- src/model_request_settings.py | 103 +++++++++++++++++++++++++++++++++ src/models/doubao.py | 8 ++- src/models/image.py | 70 +++++++++++++--------- src/models/kling.py | 8 ++- src/models/qwen_vl.py | 5 +- src/models/wanx.py | 34 +++++++---- 12 files changed, 255 insertions(+), 64 deletions(-) create mode 100644 src/model_request_settings.py diff --git a/.env.example b/.env.example index cea827e2..78703d42 100644 --- a/.env.example +++ b/.env.example @@ -1,6 +1,16 @@ # 阿里云 DashScope API Key (用于通义千问等模型) DASHSCOPE_API_KEY=your_dashscope_api_key_here +# DashScope 模型请求端点 +# 视频生成端点 +DASHSCOPE_VIDEO_CREATE_URL=https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis +# 文生图端点 +DASHSCOPE_IMAGE_T2I_URL=https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation +# 图生图端点 +DASHSCOPE_IMAGE_I2I_URL=https://dashscope.aliyuncs.com/api/v1/services/aigc/image-generation/generation +# 任务查询端点(必须包含 {task_id}) +DASHSCOPE_TASK_QUERY_URL_TEMPLATE=https://dashscope.aliyuncs.com/api/v1/tasks/{task_id} + # 阿里云 Access Key (用于 OSS、视频超分等服务) ALIBABA_CLOUD_ACCESS_KEY_ID=your_aliyun_access_key_id_here ALIBABA_CLOUD_ACCESS_KEY_SECRET=your_aliyun_access_key_secret_here @@ -13,3 +23,37 @@ OSS_BASE_PATH=comic_gen/ # API 服务配置 API_HOST=0.0.0.0 API_PORT=8000 + +# =============================== +# 模型名称配置 +# =============================== +# Wanx 默认模型 +WANX_T2V_MODEL_NAME_DEFAULT=wan2.5-t2v-preview 
+WANX_I2V_MODEL_NAME_DEFAULT=wan2.6-i2v +WANX_R2V_MODEL_NAME_DEFAULT=wan2.6-r2v + +# Wanx 走 HTTP 接口的模型列表(逗号分隔) +WANX_HTTP_I2V_MODEL_NAMES=wan2.6-i2v,wan2.5-i2v +WANX_HTTP_R2V_MODEL_NAMES=wan2.6-r2v + +# Wanx 图像模型 +WANX_IMAGE_T2I_MODEL_NAME_DEFAULT=wan2.6-t2i +WANX_IMAGE_I2I_MODEL_NAME_DEFAULT=wan2.6-image +WANX_IMAGE_HTTP_T2I_MODEL_NAMES=wan2.6-t2i +WANX_IMAGE_HTTP_I2I_MODEL_NAMES=wan2.6-image +WANX_IMAGE_FOUR_REF_MODELS=wan2.6-image + +# LLM 各环节模型 +LLM_PARSE_NOVEL_MODEL_NAME=qwen-max +LLM_STORYBOARD_ANALYSIS_MODEL_NAME=qwen-max +LLM_STYLE_RECOMMEND_MODEL_NAME=qwen-plus +LLM_STORYBOARD_POLISH_MODEL_NAME=qwen-plus +LLM_VIDEO_POLISH_MODEL_NAME=qwen-plus +LLM_R2V_POLISH_MODEL_NAME=qwen-plus + +# 其他模型默认名称 / URL +QWEN_VL_MODEL_NAME_DEFAULT=qwen-vl-plus +DOUBAO_BASE_URL=https://ark.cn-beijing.volces.com/api/v3 +DOUBAO_MODEL_NAME_DEFAULT=doubao-seedance-1-0-pro-fast-251015 +KLING_BASE_URL=https://api.klingai.com/v1 +KLING_MODEL_NAME_DEFAULT=kling-v2-5-turbo diff --git a/src/apps/comic_gen/api.py b/src/apps/comic_gen/api.py index 26333d25..28bfef47 100644 --- a/src/apps/comic_gen/api.py +++ b/src/apps/comic_gen/api.py @@ -14,6 +14,7 @@ from .pipeline import ComicGenPipeline from .models import Script, VideoTask from .llm import ScriptProcessor +from ...model_request_settings import MODEL_REQUEST_SETTINGS from ...utils.oss_utils import OSSImageUploader, sign_oss_urls_in_data from ...utils import setup_logging from fastapi.responses import JSONResponse @@ -666,7 +667,7 @@ class CreateVideoTaskRequest(BaseModel): prompt_extend: bool = True negative_prompt: Optional[str] = None batch_size: int = 1 - model: str = "wan2.6-i2v" + model: str = MODEL_REQUEST_SETTINGS.wanx_i2v_model_name_default shot_type: str = "single" # 'single' or 'multi' (only for wan2.6-i2v) generation_mode: str = "i2v" # 'i2v' (image-to-video) or 'r2v' (reference-to-video) reference_video_urls: List[str] = [] # Reference video URLs for R2V (max 3) @@ -1733,4 +1734,3 @@ async def reorder_frames(script_id: str, 
request: ReorderFramesRequest): pipeline._save_data() return {"status": "success", "message": "Frames reordered", "frame_count": len(script.frames)} - diff --git a/src/apps/comic_gen/assets.py b/src/apps/comic_gen/assets.py index 6ead4413..9e8d950f 100644 --- a/src/apps/comic_gen/assets.py +++ b/src/apps/comic_gen/assets.py @@ -5,6 +5,7 @@ from urllib.parse import quote from .models import Character, Scene, Prop, GenerationStatus, ImageAsset, ImageVariant, MAX_VARIANTS_PER_ASSET from ...models.image import WanxImageModel +from ...model_request_settings import MODEL_REQUEST_SETTINGS from ...utils import get_logger from ...utils.oss_utils import is_object_key @@ -138,7 +139,7 @@ def generate_character(self, character: Character, generation_type: str = "all", effective_generation_prompt = generation_prompt if ref_image_path: # Override to I2I model when using reference image - effective_model_name = i2i_model_name or "wan2.6-image" + effective_model_name = i2i_model_name or MODEL_REQUEST_SETTINGS.wanx_image_i2i_model_name_default logger.debug(f"Reverse generation: Using I2I model {effective_model_name} with reference image") # Enhance prompt for reverse generation to emphasize reference consistency (only if not already present) diff --git a/src/apps/comic_gen/llm.py b/src/apps/comic_gen/llm.py index b2dc5f70..656bf036 100644 --- a/src/apps/comic_gen/llm.py +++ b/src/apps/comic_gen/llm.py @@ -9,12 +9,19 @@ from .models import Script, Character, Scene, Prop, StoryboardFrame, GenerationStatus from ...utils import get_logger +from ...model_request_settings import MODEL_REQUEST_SETTINGS logger = get_logger(__name__) class ScriptProcessor: def __init__(self, api_key: str = None): self._api_key = api_key + self.parse_novel_model_name = MODEL_REQUEST_SETTINGS.llm_parse_novel_model_name + self.storyboard_analysis_model_name = MODEL_REQUEST_SETTINGS.llm_storyboard_analysis_model_name + self.style_recommend_model_name = MODEL_REQUEST_SETTINGS.llm_style_recommend_model_name + 
self.storyboard_polish_model_name = MODEL_REQUEST_SETTINGS.llm_storyboard_polish_model_name + self.video_polish_model_name = MODEL_REQUEST_SETTINGS.llm_video_polish_model_name + self.r2v_polish_model_name = MODEL_REQUEST_SETTINGS.llm_r2v_polish_model_name @property def api_key(self): @@ -40,8 +47,7 @@ def parse_novel(self, title: str, text: str) -> Script: dashscope.api_key = self.api_key response = dashscope.Generation.call( - # model='deepseek-v3.2', - model='qwen-max', + model=self.parse_novel_model_name, prompt=prompt, result_format='message', ) @@ -416,7 +422,7 @@ def analyze_script_for_styles(self, script_text: str) -> List[Dict[str, Any]]: dashscope.api_key = self.api_key response = dashscope.Generation.call( - model='qwen-plus', + model=self.style_recommend_model_name, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": user_prompt} @@ -671,7 +677,7 @@ def analyze_to_storyboard(self, text: str, entities_json: Dict[str, Any]) -> Lis dashscope.api_key = self.api_key response = dashscope.Generation.call( - model='qwen-max', + model=self.storyboard_analysis_model_name, messages=[ {"role": "system", "content": system_prompt}, {"role": "user", "content": "请开始生成分镜帧列表,确保覆盖剧本中的所有内容。"} @@ -792,7 +798,7 @@ def polish_storyboard_prompt(self, draft_prompt: str, assets: List[Dict[str, Any dashscope.api_key = self.api_key response = dashscope.Generation.call( - model='qwen-plus', + model=self.storyboard_polish_model_name, prompt=system_prompt, result_format='message', response_format={'type': 'json_object'} @@ -866,7 +872,7 @@ def polish_video_prompt(self, draft_prompt: str) -> Dict[str, str]: dashscope.api_key = self.api_key response = dashscope.Generation.call( - model='qwen-plus', + model=self.video_polish_model_name, messages=[ {'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': draft_prompt} @@ -964,7 +970,7 @@ def polish_r2v_prompt(self, draft_prompt: str, slots: List[Dict[str, str]]) -> D dashscope.api_key = 
self.api_key response = dashscope.Generation.call( - model='qwen-plus', + model=self.r2v_polish_model_name, messages=[ {'role': 'system', 'content': system_prompt}, {'role': 'user', 'content': draft_prompt} diff --git a/src/apps/comic_gen/models.py b/src/apps/comic_gen/models.py index 1088e8d8..09aa8caa 100644 --- a/src/apps/comic_gen/models.py +++ b/src/apps/comic_gen/models.py @@ -2,6 +2,7 @@ from enum import Enum import time from pydantic import BaseModel, Field +from ...model_request_settings import MODEL_REQUEST_SETTINGS class AspectRatio(str, Enum): SQUARE = "1:1" @@ -76,7 +77,7 @@ class VideoTask(BaseModel): audio_url: Optional[str] = Field(None, description="URL of generated/uploaded audio") prompt_extend: bool = Field(True, description="Whether to use prompt extension") negative_prompt: Optional[str] = Field(None, description="Negative prompt") - model: str = Field("wan2.6-i2v", description="Model used for generation") + model: str = Field(MODEL_REQUEST_SETTINGS.wanx_i2v_model_name_default, description="Model used for generation") shot_type: str = Field("single", description="Shot type: 'single' or 'multi' (only for wan2.6-i2v)") generation_mode: str = Field("i2v", description="Generation mode: 'i2v' (image-to-video) or 'r2v' (reference-to-video)") reference_video_urls: List[str] = Field(default_factory=list, description="Reference video URLs for R2V generation (max 3)") @@ -221,9 +222,9 @@ class StoryboardFrame(BaseModel): class ModelSettings(BaseModel): """Model selection settings for different generation stages""" - t2i_model: str = Field("wan2.6-t2i", description="Text-to-Image model for Assets") - i2i_model: str = Field("wan2.6-image", description="Image-to-Image model for Storyboard") - i2v_model: str = Field("wan2.6-i2v", description="Image-to-Video model for Motion") + t2i_model: str = Field(MODEL_REQUEST_SETTINGS.wanx_image_t2i_model_name_default, description="Text-to-Image model for Assets") + i2i_model: str = 
Field(MODEL_REQUEST_SETTINGS.wanx_image_i2i_model_name_default, description="Image-to-Image model for Storyboard") + i2v_model: str = Field(MODEL_REQUEST_SETTINGS.wanx_i2v_model_name_default, description="Image-to-Video model for Motion") character_aspect_ratio: str = Field("9:16", description="Aspect ratio for Characters (9:16, 16:9, 1:1)") scene_aspect_ratio: str = Field("16:9", description="Aspect ratio for Scenes (9:16, 16:9, 1:1)") prop_aspect_ratio: str = Field("1:1", description="Aspect ratio for Props (9:16, 16:9, 1:1)") diff --git a/src/apps/comic_gen/pipeline.py b/src/apps/comic_gen/pipeline.py index 9dafbc10..4ffee5df 100644 --- a/src/apps/comic_gen/pipeline.py +++ b/src/apps/comic_gen/pipeline.py @@ -14,6 +14,7 @@ from .video import VideoGenerator from .audio import AudioGenerator from .export import ExportManager +from ...model_request_settings import MODEL_REQUEST_SETTINGS from ...utils import get_logger from ...utils.oss_utils import is_object_key from ...utils.system_check import get_ffmpeg_path, get_ffmpeg_install_instructions @@ -1123,7 +1124,7 @@ def generate_motion_ref( duration=duration, created_at=time.time(), generate_audio=bool(audio_url), - model="wan2.6-i2v", + model=MODEL_REQUEST_SETTINGS.wanx_i2v_model_name_default, generation_mode="i2v" # Image to video (motion reference) ) @@ -1273,7 +1274,7 @@ def generate_video(self, script_id: str) -> Script: self._save_data() return script - def create_video_task(self, script_id: str, image_url: str, prompt: str, duration: int = 5, seed: int = None, resolution: str = "720p", generate_audio: bool = False, audio_url: str = None, prompt_extend: bool = True, negative_prompt: str = None, model: str = "wan2.6-i2v", frame_id: str = None, shot_type: str = "single", generation_mode: str = "i2v", reference_video_urls: list = None) -> Tuple[Script, str]: + def create_video_task(self, script_id: str, image_url: str, prompt: str, duration: int = 5, seed: int = None, resolution: str = "720p", generate_audio: 
bool = False, audio_url: str = None, prompt_extend: bool = True, negative_prompt: str = None, model: str = MODEL_REQUEST_SETTINGS.wanx_i2v_model_name_default, frame_id: str = None, shot_type: str = "single", generation_mode: str = "i2v", reference_video_urls: list = None) -> Tuple[Script, str]: """Creates a new video generation task.""" script = self.get_script(script_id) if not script: @@ -1283,7 +1284,7 @@ def create_video_task(self, script_id: str, image_url: str, prompt: str, duratio # If R2V mode is selected, use the R2V model if generation_mode == "r2v": - model = "wan2.6-r2v" + model = MODEL_REQUEST_SETTINGS.wanx_r2v_model_name_default # Snapshot the input image to ensure consistency snapshot_url = image_url @@ -1687,7 +1688,7 @@ def create_asset_video_task(self, script_id: str, asset_id: str, asset_type: str prompt=prompt or f"Cinematic shot of {target_asset.name}", status="pending", duration=duration, - model="wan2.6-r2v", # Force R2V model + model=MODEL_REQUEST_SETTINGS.wanx_r2v_model_name_default, # Force configured R2V model created_at=time.time() ) @@ -1888,7 +1889,7 @@ def create_asset_video_task(self, script_id: str, asset_id: str, asset_type: str status="pending", duration=duration, resolution=resolution, - model="wan2.6-i2v", # Asset video uses I2V + model=MODEL_REQUEST_SETTINGS.wanx_i2v_model_name_default, # Asset video uses configured I2V model created_at=time.time() )
diff --git a/src/model_request_settings.py b/src/model_request_settings.py
new file mode 100644
index 00000000..aae9d37b
--- /dev/null
+++ b/src/model_request_settings.py
@@ -0,0 +1,132 @@
+"""Environment-driven model names and request endpoints.
+
+All values are resolved once, at import time. A ``.env`` file in the project
+root is loaded first (when present) without overriding variables that are
+already set in the process environment.
+"""
+
+import os
+from pathlib import Path
+from typing import List, Optional
+
+from dotenv import load_dotenv
+
+
+_PROJECT_ROOT = Path(__file__).resolve().parent.parent
+_ENV_PATH = _PROJECT_ROOT / ".env"
+if _ENV_PATH.exists():
+    load_dotenv(_ENV_PATH, override=False)
+
+
+def _env_str(name: str, default: str) -> str:
+    """Return environment variable ``name``, or ``default`` if unset or blank.
+
+    A line such as ``KLING_BASE_URL=`` in ``.env`` defines the variable as an
+    empty string; treating that as "not configured" keeps scalar settings
+    consistent with ``_split_csv``, which already falls back on blank input.
+    """
+    return os.getenv(name, "").strip() or default
+
+
+def _split_csv(value: Optional[str], default: List[str]) -> List[str]:
+    """Split a comma-separated string into stripped, non-empty items.
+
+    Returns a copy of ``default`` when ``value`` is empty or yields no items,
+    so the caller's default list is never shared with the result.
+    """
+    if not value:
+        return list(default)
+    items = [item.strip() for item in value.split(",") if item.strip()]
+    return items or list(default)
+
+
+class ModelRequestSettings:
+    """Model names and endpoint URLs, each overridable via an environment variable.
+
+    Attributes are class-level, so the environment is read once when this
+    module is first imported; later changes to ``os.environ`` are not seen.
+    """
+
+    # DashScope request URLs
+    dashscope_video_create_url: str = _env_str(
+        "DASHSCOPE_VIDEO_CREATE_URL",
+        "https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis",
+    )
+    dashscope_image_t2i_url: str = _env_str(
+        "DASHSCOPE_IMAGE_T2I_URL",
+        "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation",
+    )
+    dashscope_image_i2i_url: str = _env_str(
+        "DASHSCOPE_IMAGE_I2I_URL",
+        "https://dashscope.aliyuncs.com/api/v1/services/aigc/image-generation/generation",
+    )
+    # Filled in with ``str.format(task_id=...)``, so it must keep ``{task_id}``.
+    dashscope_task_query_url_template: str = _env_str(
+        "DASHSCOPE_TASK_QUERY_URL_TEMPLATE",
+        "https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}",
+    )
+
+    # Wanx model names
+    wanx_t2v_model_name_default: str = _env_str(
+        "WANX_T2V_MODEL_NAME_DEFAULT",
+        "wan2.5-t2v-preview",
+    )
+    wanx_i2v_model_name_default: str = _env_str(
+        "WANX_I2V_MODEL_NAME_DEFAULT",
+        "wan2.6-i2v",
+    )
+    wanx_r2v_model_name_default: str = _env_str(
+        "WANX_R2V_MODEL_NAME_DEFAULT",
+        "wan2.6-r2v",
+    )
+    wanx_http_i2v_model_names: List[str] = _split_csv(
+        os.getenv("WANX_HTTP_I2V_MODEL_NAMES", ""),
+        ["wan2.6-i2v", "wan2.5-i2v"],
+    )
+    wanx_http_r2v_model_names: List[str] = _split_csv(
+        os.getenv("WANX_HTTP_R2V_MODEL_NAMES", ""),
+        ["wan2.6-r2v"],
+    )
+
+    # Wanx image model names
+    wanx_image_t2i_model_name_default: str = _env_str(
+        "WANX_IMAGE_T2I_MODEL_NAME_DEFAULT",
+        "wan2.6-t2i",
+    )
+    wanx_image_i2i_model_name_default: str = _env_str(
+        "WANX_IMAGE_I2I_MODEL_NAME_DEFAULT",
+        "wan2.6-image",
+    )
+    wanx_image_four_ref_models: List[str] = _split_csv(
+        os.getenv("WANX_IMAGE_FOUR_REF_MODELS", ""),
+        ["wan2.6-image"],
+    )
+    wanx_image_http_t2i_model_names: List[str] = _split_csv(
+        os.getenv("WANX_IMAGE_HTTP_T2I_MODEL_NAMES", ""),
+        ["wan2.6-t2i"],
+    )
+    wanx_image_http_i2i_model_names: List[str] = _split_csv(
+        os.getenv("WANX_IMAGE_HTTP_I2I_MODEL_NAMES", ""),
+        ["wan2.6-image"],
+    )
+
+    # LLM model names
+    llm_parse_novel_model_name: str = _env_str("LLM_PARSE_NOVEL_MODEL_NAME", "qwen-max")
+    llm_storyboard_analysis_model_name: str = _env_str("LLM_STORYBOARD_ANALYSIS_MODEL_NAME", "qwen-max")
+    llm_style_recommend_model_name: str = _env_str("LLM_STYLE_RECOMMEND_MODEL_NAME", "qwen-plus")
+    llm_storyboard_polish_model_name: str = _env_str("LLM_STORYBOARD_POLISH_MODEL_NAME", "qwen-plus")
+    llm_video_polish_model_name: str = _env_str("LLM_VIDEO_POLISH_MODEL_NAME", "qwen-plus")
+    llm_r2v_polish_model_name: str = _env_str("LLM_R2V_POLISH_MODEL_NAME", "qwen-plus")
+
+    # Other model providers
+    qwen_vl_model_name_default: str = _env_str("QWEN_VL_MODEL_NAME_DEFAULT", "qwen-vl-plus")
+    doubao_base_url: str = _env_str("DOUBAO_BASE_URL", "https://ark.cn-beijing.volces.com/api/v3")
+    doubao_model_name_default: str = _env_str(
+        "DOUBAO_MODEL_NAME_DEFAULT",
+        "doubao-seedance-1-0-pro-fast-251015",
+    )
+    kling_base_url: str = _env_str("KLING_BASE_URL", "https://api.klingai.com/v1")
+    kling_model_name_default: str = _env_str("KLING_MODEL_NAME_DEFAULT", "kling-v2-5-turbo")
+
+
+MODEL_REQUEST_SETTINGS = ModelRequestSettings()
diff --git a/src/models/doubao.py b/src/models/doubao.py index 20049b9e..fce55312 100644 --- a/src/models/doubao.py +++ b/src/models/doubao.py @@ -4,6 +4,7 @@ import base64 from typing import Tuple, Optional from .base import VideoGenModel +from ..model_request_settings import MODEL_REQUEST_SETTINGS # Try to import Ark, handle if not installed (though user said they installed it) try: @@ -17,14 +18,17 @@ class DoubaoModel(VideoGenModel): def __init__(self, config: dict): super().__init__(config) self.api_key = os.getenv("ARK_API_KEY") - self.model_name = config.get('params', {}).get('model_name', 'doubao-seedance-1-0-pro-fast-251015') + self.model_name = config.get('params', {}).get( + 'model_name', MODEL_REQUEST_SETTINGS.doubao_model_name_default + ) + self.base_url = MODEL_REQUEST_SETTINGS.doubao_base_url
if not self.api_key: logger.warning("ARK_API_KEY not found in environment variables.") if Ark: self.client = Ark( - base_url="https://ark.cn-beijing.volces.com/api/v3", + base_url=self.base_url, api_key=self.api_key ) else: diff --git a/src/models/image.py b/src/models/image.py index a30477ac..cfff2c9c 100644 --- a/src/models/image.py +++ b/src/models/image.py @@ -8,6 +8,7 @@ from dashscope import ImageSynthesis from ..utils import get_logger from ..utils.oss_utils import OSSImageUploader +from ..model_request_settings import MODEL_REQUEST_SETTINGS logger = get_logger(__name__) @@ -38,6 +39,14 @@ class WanxImageModel(ImageGenModel): def __init__(self, config): super().__init__(config) self.params = config.get('params', {}) + self.image_t2i_url = MODEL_REQUEST_SETTINGS.dashscope_image_t2i_url + self.image_i2i_url = MODEL_REQUEST_SETTINGS.dashscope_image_i2i_url + self.task_query_url_template = MODEL_REQUEST_SETTINGS.dashscope_task_query_url_template + self.default_t2i_model_name = MODEL_REQUEST_SETTINGS.wanx_image_t2i_model_name_default + self.default_i2i_model_name = MODEL_REQUEST_SETTINGS.wanx_image_i2i_model_name_default + self.http_t2i_models = set(MODEL_REQUEST_SETTINGS.wanx_image_http_t2i_model_names) + self.http_i2i_models = set(MODEL_REQUEST_SETTINGS.wanx_image_http_i2i_model_names) + self.four_ref_models = set(MODEL_REQUEST_SETTINGS.wanx_image_four_ref_models) @property def api_key(self): @@ -63,11 +72,11 @@ def generate(self, prompt: str, output_path: str, ref_image_path: str = None, re if model_name: final_model_name = model_name elif all_ref_paths: - # For I2I, use i2i_model_name if configured, otherwise default to wan2.5-i2i-preview - final_model_name = self.params.get('i2i_model_name', 'wan2.5-i2i-preview') + # For I2I, use i2i_model_name if configured, otherwise use configured default + final_model_name = self.params.get('i2i_model_name', self.default_i2i_model_name) else: - # For T2I, use model_name if configured, otherwise default to wan2.6-t2i - 
final_model_name = self.params.get('model_name', 'wan2.6-t2i') + # For T2I, use model_name if configured, otherwise use configured default + final_model_name = self.params.get('model_name', self.default_t2i_model_name) if all_ref_paths: logger.info(f"Using I2I model: {final_model_name} with {len(all_ref_paths)} reference images") @@ -81,7 +90,7 @@ def generate(self, prompt: str, output_path: str, ref_image_path: str = None, re kwargs.pop('model_name', None) # Determine reference image limit based on model - ref_limit = 4 if final_model_name == 'wan2.6-image' else 3 + ref_limit = 4 if final_model_name in self.four_ref_models else 3 if len(all_ref_paths) > ref_limit: logger.warning(f"Limiting reference images from {len(all_ref_paths)} to {ref_limit} for model {final_model_name}") all_ref_paths = all_ref_paths[:ref_limit] @@ -92,12 +101,13 @@ def generate(self, prompt: str, output_path: str, ref_image_path: str = None, re try: api_start_time = time.time() - # Use HTTP API for wan2.6 models (SDK not supported yet) - if final_model_name == 'wan2.6-t2i': - image_url = self._generate_wan26_http(prompt, size, n, negative_prompt) - elif final_model_name == 'wan2.6-image': - # wan2.6-image for I2I (requires reference images) - image_url = self._generate_wan26_image_http(prompt, size, n, negative_prompt, all_ref_paths) + # Use HTTP API for configured models (SDK not supported yet) + if final_model_name in self.http_t2i_models: + image_url = self._generate_wan26_http(prompt, size, n, negative_prompt, final_model_name) + elif final_model_name in self.http_i2i_models: + image_url = self._generate_wan26_image_http( + prompt, size, n, negative_prompt, all_ref_paths, final_model_name + ) else: # Use SDK for other models image_url = self._generate_sdk(prompt, final_model_name, size, n, negative_prompt, all_ref_paths, @@ -119,9 +129,12 @@ def generate(self, prompt: str, output_path: str, ref_image_path: str = None, re logger.error(traceback.format_exc()) raise - def 
_generate_wan26_http(self, prompt: str, size: str, n: int, negative_prompt: str = None) -> str: - """Generate image using Wan 2.6 T2I via HTTP API (synchronous).""" - url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/multimodal-generation/generation" + def _generate_wan26_http( + self, prompt: str, size: str, n: int, negative_prompt: str = None, model_name: str = None + ) -> str: + """Generate image using configured T2I HTTP API (synchronous).""" + model_name = model_name or self.default_t2i_model_name + url = self.image_t2i_url headers = { "Content-Type": "application/json", @@ -129,7 +142,7 @@ def _generate_wan26_http(self, prompt: str, size: str, n: int, negative_prompt: } payload = { - "model": "wan2.6-t2i", + "model": model_name, "input": { "messages": [ { @@ -154,7 +167,7 @@ def _generate_wan26_http(self, prompt: str, size: str, n: int, negative_prompt: if negative_prompt: payload["parameters"]["negative_prompt"] = negative_prompt - logger.info(f"Calling Wan 2.6 T2I HTTP API...") + logger.info(f"Calling {model_name} HTTP API...") logger.info(f"Payload: {payload}") response = requests.post(url, headers=headers, json=payload, timeout=300) # 5 minutes for slow API responses @@ -187,9 +200,12 @@ def _generate_wan26_http(self, prompt: str, size: str, n: int, negative_prompt: return image_url - def _generate_wan26_image_http(self, prompt: str, size: str, n: int, negative_prompt: str = None, ref_image_paths: list = None) -> str: - """Generate image using Wan 2.6 Image via HTTP API (asynchronous with polling).""" - create_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/image-generation/generation" + def _generate_wan26_image_http( + self, prompt: str, size: str, n: int, negative_prompt: str = None, ref_image_paths: list = None, model_name: str = None + ) -> str: + """Generate image using configured I2I HTTP API (asynchronous with polling).""" + model_name = model_name or self.default_i2i_model_name + create_url = self.image_i2i_url headers = { 
"Content-Type": "application/json", @@ -238,7 +254,7 @@ def _generate_wan26_image_http(self, prompt: str, size: str, n: int, negative_pr logger.warning(f"Reference image not found: {path}") payload = { - "model": "wan2.6-image", + "model": model_name, "input": { "messages": [ { @@ -260,7 +276,7 @@ def _generate_wan26_image_http(self, prompt: str, size: str, n: int, negative_pr if negative_prompt: payload["parameters"]["negative_prompt"] = negative_prompt - logger.info(f"Calling Wan 2.6 Image HTTP API (async)...") + logger.info(f"Calling {model_name} HTTP API (async)...") logger.info(f"Payload: {payload}") # Step 1: Create task @@ -272,7 +288,7 @@ def _generate_wan26_image_http(self, prompt: str, size: str, n: int, negative_pr if response.status_code != 200: error_data = response.json() if response.text else {} error_msg = error_data.get('message', response.text) - raise RuntimeError(f"Wan 2.6 Image task creation failed: {error_msg}") + raise RuntimeError(f"{model_name} task creation failed: {error_msg}") result = response.json() task_id = result.get('output', {}).get('task_id') @@ -282,7 +298,7 @@ def _generate_wan26_image_http(self, prompt: str, size: str, n: int, negative_pr logger.info(f"Task created: {task_id}") # Step 2: Poll for task completion - poll_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" + poll_url = self.task_query_url_template.format(task_id=task_id) poll_headers = { "Authorization": f"Bearer {self.api_key}" } @@ -337,15 +353,15 @@ def _generate_wan26_image_http(self, prompt: str, size: str, n: int, negative_pr 'Unknown error - check logs for full response' ) - raise RuntimeError(f"Wan 2.6 Image task failed: {error_msg}") + raise RuntimeError(f"{model_name} task failed: {error_msg}") elif task_status in ['CANCELED', 'UNKNOWN']: - raise RuntimeError(f"Wan 2.6 Image task {task_status}: {poll_result}") + raise RuntimeError(f"{model_name} task {task_status}: {poll_result}") # PENDING or RUNNING - continue polling - raise 
RuntimeError(f"Wan 2.6 Image task timed out after {max_wait_time}s") + raise RuntimeError(f"{model_name} task timed out after {max_wait_time}s") def _generate_sdk(self, prompt: str, model_name: str, size: str, n: int, negative_prompt: str, all_ref_paths: list, kwargs: dict) -> str: """Generate image using Dashscope SDK (for older models).""" @@ -402,7 +418,7 @@ def _generate_sdk(self, prompt: str, model_name: str, size: str, n: int, negativ logger.info(f"DEBUG: ref_image_urls count: {len(ref_image_urls)}") # Limit is already handled in generate(), but we keep a safety slice here - ref_limit = 4 if model_name == 'wan2.6-image' else 3 + ref_limit = 4 if model_name in self.four_ref_models else 3 if len(ref_image_urls) > ref_limit: logger.warning(f"Limiting reference images from {len(ref_image_urls)} to {ref_limit}") ref_image_urls = ref_image_urls[:ref_limit] diff --git a/src/models/kling.py b/src/models/kling.py index 43bb3f47..4e26ee21 100644 --- a/src/models/kling.py +++ b/src/models/kling.py @@ -4,6 +4,7 @@ import logging from typing import Dict, Any, Tuple from .base import VideoGenModel +from ..model_request_settings import MODEL_REQUEST_SETTINGS logger = logging.getLogger(__name__) @@ -11,8 +12,10 @@ class KlingModel(VideoGenModel): def __init__(self, config: Dict[str, Any]): super().__init__(config) self.api_key = config.get("api_key") - self.base_url = "https://api.klingai.com/v1" - self.model_name = config.get("params", {}).get("model_name", "kling-v2-5-turbo") + self.base_url = MODEL_REQUEST_SETTINGS.kling_base_url + self.model_name = config.get("params", {}).get( + "model_name", MODEL_REQUEST_SETTINGS.kling_model_name_default + ) def _get_token(self) -> str: """Generate JWT token for Kling API.""" @@ -101,4 +104,3 @@ def generate(self, prompt: str, output_path: str, img_url: str = None, **kwargs) except Exception as e: logger.error(f"Error polling Kling task: {e}") raise - diff --git a/src/models/qwen_vl.py b/src/models/qwen_vl.py index 8f654d05..76bdba9e 
100644 --- a/src/models/qwen_vl.py +++ b/src/models/qwen_vl.py @@ -4,6 +4,7 @@ from typing import Tuple import dashscope +from ..model_request_settings import MODEL_REQUEST_SETTINGS logger = logging.getLogger(__name__) @@ -38,7 +39,9 @@ class QwenVLModel: def __init__(self, config: dict): - self.model_name = config.get('params', {}).get('model_name', 'qwen-vl-plus') + self.model_name = config.get('params', {}).get( + 'model_name', MODEL_REQUEST_SETTINGS.qwen_vl_model_name_default + ) @property def api_key(self): diff --git a/src/models/wanx.py b/src/models/wanx.py index 9e05143d..fa8c061f 100644 --- a/src/models/wanx.py +++ b/src/models/wanx.py @@ -6,6 +6,7 @@ import dashscope from .base import VideoGenModel from ..utils import get_logger +from ..model_request_settings import MODEL_REQUEST_SETTINGS from typing import Tuple @@ -19,6 +20,13 @@ def __init__(self, config): super().__init__(config) self.params = config.get('params', {}) + self.video_create_url = MODEL_REQUEST_SETTINGS.dashscope_video_create_url + self.task_query_url_template = MODEL_REQUEST_SETTINGS.dashscope_task_query_url_template + self.default_t2v_model_name = MODEL_REQUEST_SETTINGS.wanx_t2v_model_name_default + self.default_i2v_model_name = MODEL_REQUEST_SETTINGS.wanx_i2v_model_name_default + self.default_r2v_model_name = MODEL_REQUEST_SETTINGS.wanx_r2v_model_name_default + self.http_i2v_models = set(MODEL_REQUEST_SETTINGS.wanx_http_i2v_model_names) + self.http_r2v_models = set(MODEL_REQUEST_SETTINGS.wanx_http_r2v_model_names) @property def api_key(self): @@ -36,10 +44,10 @@ def generate(self, prompt: str, output_path: str, img_path: str = None, model_na final_model_name = kwargs.get('model') logger.info(f"Using model from kwargs: {final_model_name}") elif img_path or kwargs.get('img_url'): - final_model_name = self.params.get('i2v_model_name', 'wan2.6-i2v') # Default to I2V model + final_model_name = self.params.get('i2v_model_name', self.default_i2v_model_name) logger.info(f"Using I2V model: 
{final_model_name}") else: - final_model_name = self.params.get('model_name', 'wan2.5-t2v-preview') + final_model_name = self.params.get('model_name', self.default_t2v_model_name) logger.info(f"Using T2V model: {final_model_name}") size = self.params.get('size', '1280*720') @@ -110,8 +118,8 @@ def generate(self, prompt: str, output_path: str, img_path: str = None, model_na else: logger.warning(f"OSS not configured, cannot sign Object Key in img_url: {img_url}") - # Use HTTP API for wan2.6-i2v, wan2.5-i2v, or wan2.6-r2v - if final_model_name in ['wan2.6-i2v', 'wan2.5-i2v']: + # Use HTTP API for configured I2V/R2V models + if final_model_name in self.http_i2v_models: # Get shot_type from kwargs (only for wan I2V models) shot_type = kwargs.get('shot_type', 'single') video_url = self._generate_wan_i2v_http( @@ -127,11 +135,11 @@ def generate(self, prompt: str, output_path: str, img_path: str = None, model_na seed=seed, shot_type=shot_type ) - elif final_model_name == 'wan2.6-r2v': + elif final_model_name in self.http_r2v_models: # R2V generation ref_video_urls = kwargs.get('ref_video_urls', []) if not ref_video_urls: - raise ValueError("ref_video_urls is required for wan2.6-r2v") + raise ValueError(f"ref_video_urls is required for {final_model_name}") # Process ref_video_urls: Upload local files or sign Object Keys processed_ref_urls = [] @@ -217,14 +225,15 @@ def generate(self, prompt: str, output_path: str, img_path: str = None, model_na logger.error(f"Error during generation: {e}") raise - def _generate_wan_i2v_http(self, prompt: str, img_url: str, model_name: str = "wan2.6-i2v", + def _generate_wan_i2v_http(self, prompt: str, img_url: str, model_name: str = None, resolution: str = "720P", duration: int = 5, prompt_extend: bool = True, negative_prompt: str = None, audio_url: str = None, watermark: bool = False, seed: int = None, shot_type: str = "single") -> str: """Generate video using Wan I2V (2.5 or 2.6) via HTTP API (asynchronous with polling).""" - create_url = 
"https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis" + model_name = model_name or self.default_i2v_model_name + create_url = self.video_create_url headers = { "Content-Type": "application/json", @@ -279,7 +288,7 @@ def _generate_wan_i2v_http(self, prompt: str, img_url: str, model_name: str = "w logger.info(f"Task created: {task_id}") # Step 2: Poll for task completion - poll_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" + poll_url = self.task_query_url_template.format(task_id=task_id) poll_headers = { "Authorization": f"Bearer {self.api_key}" } @@ -323,12 +332,13 @@ def _generate_wan_i2v_http(self, prompt: str, img_url: str, model_name: str = "w raise RuntimeError(f"{model_name} task timed out after {max_wait_time}s") - def _generate_wan_r2v_http(self, prompt: str, ref_video_urls: list, model_name: str = "wan2.6-r2v", + def _generate_wan_r2v_http(self, prompt: str, ref_video_urls: list, model_name: str = None, size: str = "1280*720", duration: int = 5, audio: bool = True, shot_type: str = "multi", seed: int = None) -> str: """Generate video using Wan R2V via HTTP API (asynchronous with polling).""" - create_url = "https://dashscope.aliyuncs.com/api/v1/services/aigc/video-generation/video-synthesis" + model_name = model_name or self.default_r2v_model_name + create_url = self.video_create_url headers = { "Content-Type": "application/json", @@ -375,7 +385,7 @@ def _generate_wan_r2v_http(self, prompt: str, ref_video_urls: list, model_name: logger.info(f"Task created: {task_id}") # Step 2: Poll for task completion - poll_url = f"https://dashscope.aliyuncs.com/api/v1/tasks/{task_id}" + poll_url = self.task_query_url_template.format(task_id=task_id) poll_headers = { "Authorization": f"Bearer {self.api_key}" }