|
| 1 | +# -*- coding: utf-8 -*- |
| 2 | +# pylint:disable=abstract-method, deprecated-module, wrong-import-order |
| 3 | +# pylint:disable=too-many-nested-blocks, too-many-branches, unused-import |
| 4 | + |
| 5 | +import os |
| 6 | +import uuid |
| 7 | +from typing import Any, Optional, List |
| 8 | + |
| 9 | +from dashscope import AioMultiModalConversation |
| 10 | +from mcp.server.fastmcp import Context |
| 11 | +from pydantic import BaseModel, Field |
| 12 | + |
| 13 | +from ..base import Tool |
| 14 | +from ..utils.api_key_util import get_api_key, ApiNames |
| 15 | +from ...engine.tracing import trace, TracingUtil |
| 16 | + |
| 17 | + |
class QwenImageEditInputNew(BaseModel):
    """
    Input schema for the Qwen image-edit (multi-image fusion) tool.

    Accepts one or more publicly reachable image URLs plus a text prompt.
    The ``description`` strings below are runtime metadata surfaced
    verbatim to the tool-calling layer, so they are intentionally kept
    in Chinese.
    """

    # Public HTTP/HTTPS URLs of the input images. Per the description:
    # JPG/JPEG/PNG/BMP/TIFF/WEBP, resolution within [384, 3072],
    # <= 10 MB each, and the URL must not contain Chinese characters.
    image_urls: list[str] = Field(
        ...,
        description="输入图像的URL地址列表,每个URL需为公网可访问地址,支持 HTTP 或 "
        "HTTPS 协议。格式:JPG、JPEG、PNG、BMP、TIFF、WEBP,分辨率[384,"
        "3072],大小不超过10MB。URL不能包含中文字符。",
    )
    # Positive prompt describing the desired elements and visual style;
    # per the service contract it is truncated beyond 800 characters.
    prompt: str = Field(
        ...,
        description="正向提示词,用来描述生成图像中期望包含的元素和视觉特点, 超过800个字符自动截断",
    )
    # Optional negative prompt constraining unwanted content; truncated
    # beyond 500 characters.
    negative_prompt: Optional[str] = Field(
        default=None,
        description="反向提示词,用来描述不希望在画面中看到的内容,可以对画面进行限制,超过500个字符自动截断",
    )
    # Whether to watermark the output; None means "not specified", in
    # which case the parameter is omitted from the downstream API call.
    watermark: Optional[bool] = Field(
        default=None,
        description="是否添加水印,默认不设置",
    )
    # MCP-only request context carrying HTTP headers; must never be
    # generated by the model itself.
    ctx: Optional[Context] = Field(
        default=None,
        description="HTTP request context containing headers for mcp only, "
        "don't generate it",
    )
| 47 | + |
class QwenImageEditOutputNew(BaseModel):
    """
    Output schema for the Qwen image-edit (multi-image fusion) tool.
    """

    # URLs of the fused result images; per the description the service
    # returns exactly one URL in this list.
    results: list[str] = Field(
        title="Results",
        description="输出的融合后图片url列表,仅包含1个URL",
    )
    # DashScope request id, useful for tracing and support tickets.
    request_id: Optional[str] = Field(
        default=None,
        title="Request ID",
        description="请求ID",
    )
| 63 | + |
class QwenImageEditNew(Tool[QwenImageEditInputNew, QwenImageEditOutputNew]):
    """
    Qwen Image Edit Tool for AI-powered image editing.

    Fuses one or more input images into a single new image according to
    a natural-language prompt, via DashScope's MultiModalConversation
    service.
    """

    name: str = "modelstudio_qwen_image_edit_new"
    description: str = (
        "通义千问-多图融合模型,基于 qwen-image-edit,支持将多张图像按提示词语义融合为一张新图。"
        "可用于风格混合、场景合成、元素组合等复杂图像生成任务。"
    )

    @staticmethod
    def _normalize_watermark(watermark: Any) -> Optional[bool]:
        """Coerce the watermark value to a bool, or None if unset.

        Accepts booleans and any truthy value; the strings "true"/"1"
        (case-insensitive, whitespace-tolerant) map to True for callers
        that bypass pydantic validation.
        """
        if watermark is None:
            return None
        if isinstance(watermark, str):
            return watermark.strip().lower() in ("true", "1")
        return bool(watermark)

    @staticmethod
    def _extract_result_urls(response: Any) -> List[str]:
        """Extract edited-image URLs from a MultiModalConversation response.

        The API may return ``output.choices[0].message.content`` as a
        list of ``{"image": url}`` items, a plain string, or a single
        dict; all three shapes are handled. Returns an empty list when
        no URL can be found.
        """
        results: List[str] = []
        output = getattr(response, "output", None)
        if not output:
            return results
        choices = getattr(output, "choices", [])
        if not choices:
            return results
        message = getattr(choices[0], "message", {})
        if not hasattr(message, "content"):
            return results
        content = message.content
        if isinstance(content, list):
            results.extend(
                item["image"]
                for item in content
                if isinstance(item, dict) and "image" in item
            )
        elif isinstance(content, str):
            results.append(content)
        elif isinstance(content, dict) and "image" in content:
            results.append(content["image"])
        return results

    @trace(trace_type="AIGC", trace_name="qwen_image_edit_new")
    async def arun(
        self,
        args: QwenImageEditInputNew,
        **kwargs: Any,
    ) -> QwenImageEditOutputNew:
        """Edit/fuse images using the MultiModalConversation API.

        This method uses DashScope's MultiModalConversation service to
        edit images based on text prompts. The API supports various
        image editing operations through natural language instructions.

        Args:
            args: QwenImageEditInputNew containing image_urls, prompt,
                negative_prompt, and watermark.
            **kwargs: Additional keyword arguments including request_id,
                trace_event, model_name, api_key.

        Returns:
            QwenImageEditOutputNew containing the edited image URL(s)
            and the request ID.

        Raises:
            ValueError: If DASHSCOPE_API_KEY is not set or invalid.
            RuntimeError: If the API call fails, returns a non-200
                status, yields no result URLs, or its response cannot
                be parsed.
        """
        trace_event = kwargs.pop("trace_event", None)
        request_id = TracingUtil.get_request_id()

        try:
            api_key = get_api_key(ApiNames.dashscope_api_key, **kwargs)
        except AssertionError as e:
            raise ValueError("Please set valid DASHSCOPE_API_KEY!") from e

        # Per-call override first, then env var, then the default model.
        model_name = kwargs.get(
            "model_name",
            os.getenv("QWEN_IMAGE_EDIT_MODEL_NAME", "qwen-image-edit"),
        )

        # Build the single-turn message expected by
        # MultiModalConversation: all input images followed by the text
        # prompt.
        message_content: List[dict] = [
            {"image": url} for url in args.image_urls
        ]
        message_content.append({"text": args.prompt})
        messages = [
            {
                "role": "user",
                "content": message_content,
            },
        ]

        # Normalize the watermark flag locally instead of mutating
        # ``args`` in place, and only forward optional parameters that
        # were actually provided.
        watermark = self._normalize_watermark(args.watermark)
        parameters: dict = {}
        if args.negative_prompt:
            parameters["negative_prompt"] = args.negative_prompt
        if watermark is not None:
            parameters["watermark"] = watermark

        # Call the AioMultiModalConversation API asynchronously.
        try:
            response = await AioMultiModalConversation.call(
                api_key=api_key,
                model=model_name,
                messages=messages,
                **parameters,
            )
        except Exception as e:
            raise RuntimeError(
                f"Failed to call Qwen Image Edit API: {str(e)}",
            ) from e

        # Check response status before attempting to parse it.
        if response.status_code != 200 or not response.output:
            raise RuntimeError(f"Failed to generate: {response}")

        # Unexpected response shapes become a parse failure; an empty
        # result set is raised OUTSIDE the try so it is reported with
        # its own message instead of being re-wrapped.
        try:
            results = self._extract_result_urls(response)
        except Exception as e:
            raise RuntimeError(
                f"Failed to parse response from Qwen Image Edit API: {str(e)}",
            ) from e

        if not results:
            raise RuntimeError(
                f"Could not extract edited image URLs from response: "
                f"{response}",
            )

        # Fall back to the response's request id (or a fresh UUID) when
        # tracing did not supply one.
        if request_id == "":
            request_id = getattr(response, "request_id", None) or str(
                uuid.uuid4(),
            )

        # Log trace event if provided.
        if trace_event:
            trace_event.on_log(
                "",
                **{
                    "step_suffix": "results",
                    "payload": {
                        "request_id": request_id,
                        "qwen_image_edit_result": {
                            "status_code": response.status_code,
                            "results": results,
                        },
                    },
                },
            )

        return QwenImageEditOutputNew(
            results=results,
            request_id=request_id,
        )
0 commit comments