|
| 1 | +import os |
| 2 | +import aiohttp |
| 3 | +import google.generativeai as genai |
| 4 | + |
| 5 | +from io import BytesIO |
| 6 | +from PIL import Image as PILImage |
| 7 | +from nonebot.typing import T_State |
| 8 | +from nonebot.matcher import Matcher |
| 9 | +from nonebot.plugin import PluginMetadata |
| 10 | +from nonebot.adapters import Message, Event, Bot |
| 11 | +from nonebot import require, get_driver, on_command |
| 12 | +from nonebot.params import CommandArg, ArgPlainText |
| 13 | +from google.generativeai.generative_models import ChatSession |
| 14 | + |
| 15 | +from .config import Config |
| 16 | + |
| 17 | +require("nonebot_plugin_alconna") |
| 18 | +require("nonebot_plugin_htmlrender") |
| 19 | + |
| 20 | +from nonebot_plugin_alconna import UniMessage, Text, Image |
| 21 | +from nonebot_plugin_htmlrender import md_to_pic |
| 22 | + |
| 23 | + |
| 24 | +__plugin_meta__ = PluginMetadata( |
| 25 | + name="nonebot-plugin-gemini", |
| 26 | + description="Gemini AI 对话", |
| 27 | + usage="gemini [文本/图片] -Gemini 生成回复\ngeminichat (可选)[文本] -开始 Gemini 对话\n结束对话 -结束 Gemini 对话", |
| 28 | + type="application", |
| 29 | + homepage="https://github.com/zhaomaoniu/nonebot-plugin-gemini", |
| 30 | + config=Config, |
| 31 | + supported_adapters=None, |
| 32 | +) |
| 33 | + |
| 34 | + |
| 35 | +plugin_config = Config.parse_obj(get_driver().config) |
| 36 | + |
| 37 | + |
| 38 | +GOOGLE_API_KEY = plugin_config.google_api_key or os.environ.get("GOOGLE_API_KEY", None) |
| 39 | + |
| 40 | + |
| 41 | +if GOOGLE_API_KEY is None: |
| 42 | + raise ValueError("GOOGLE_API_KEY 未配置, nonebot-plugin-gemini 无法运行") |
| 43 | + |
| 44 | + |
| 45 | +genai.configure(api_key=GOOGLE_API_KEY) |
| 46 | + |
| 47 | +models = { |
| 48 | + "gemini-pro": genai.GenerativeModel("gemini-pro"), |
| 49 | + "gemini-pro-vision": genai.GenerativeModel("gemini-pro-vision"), |
| 50 | +} |
| 51 | + |
| 52 | + |
| 53 | +async def to_markdown(text: str) -> bytes: |
| 54 | + text = text.replace("•", " *") |
| 55 | + return await md_to_pic(text, width=800) |
| 56 | + |
| 57 | + |
| 58 | +async def to_pil_image(image: Image) -> PILImage: |
| 59 | + if image.raw is not None: |
| 60 | + return PILImage.open( |
| 61 | + image.raw.getvalue() if isinstance(image.raw, BytesIO) else image.raw |
| 62 | + ) |
| 63 | + |
| 64 | + try: |
| 65 | + return PILImage.open(image.raw_bytes) |
| 66 | + except ValueError: |
| 67 | + pass |
| 68 | + |
| 69 | + if image.path is not None: |
| 70 | + return PILImage.open(image.path) |
| 71 | + |
| 72 | + if image.url is not None: |
| 73 | + async with aiohttp.ClientSession() as session: |
| 74 | + async with session.get(image.url) as resp: |
| 75 | + data = await resp.read() |
| 76 | + return PILImage.open(BytesIO(data)) |
| 77 | + |
| 78 | + raise ValueError("无法获取图片") |
| 79 | + |
| 80 | + |
| 81 | +chat = on_command("gemini", priority=10, block=True) |
| 82 | +conversation = on_command("geminichat", priority=5, block=True) |
| 83 | + |
| 84 | + |
| 85 | +@chat.handle() |
| 86 | +async def _(event: Event, bot: Bot, message: Message = CommandArg()): |
| 87 | + uni_message = await UniMessage.generate(message=message, event=event, bot=bot) |
| 88 | + |
| 89 | + msg = [] |
| 90 | + model = "gemini-pro" |
| 91 | + |
| 92 | + for seg in uni_message: |
| 93 | + if isinstance(seg, Text): |
| 94 | + msg.append(seg.text) |
| 95 | + |
| 96 | + elif isinstance(seg, Image): |
| 97 | + model = "gemini-pro-vision" |
| 98 | + msg.append(await to_pil_image(seg)) |
| 99 | + |
| 100 | + if not msg: |
| 101 | + await chat.finish("未获取到有效输入,输入应为文本或图片") |
| 102 | + |
| 103 | + try: |
| 104 | + resp = await models[model].generate_content_async(msg) |
| 105 | + except Exception as e: |
| 106 | + await chat.finish(f"{type(e).__name__}: {e}") |
| 107 | + |
| 108 | + try: |
| 109 | + result = resp.text |
| 110 | + except ValueError: |
| 111 | + result = "\n---\n".join( |
| 112 | + [part.text for part in resp.candidates[0].content.parts] |
| 113 | + ) |
| 114 | + |
| 115 | + await chat.finish( |
| 116 | + await UniMessage(Image(raw=await to_markdown(result))).export() |
| 117 | + if len(result) > 500 |
| 118 | + else result.strip() |
| 119 | + ) |
| 120 | + |
| 121 | + |
| 122 | +@conversation.handle() |
| 123 | +async def start_conversation( |
| 124 | + state: T_State, matcher: Matcher, args: Message = CommandArg() |
| 125 | +): |
| 126 | + if args.extract_plain_text() != "": |
| 127 | + matcher.set_arg(key="msg", message=args) |
| 128 | + |
| 129 | + state["gemini_chat_session"] = models["gemini-pro"].start_chat(history=[]) |
| 130 | + |
| 131 | + |
| 132 | +@conversation.got("msg", prompt="对话开始") |
| 133 | +async def got_message(state: T_State, msg: str = ArgPlainText()): |
| 134 | + if msg in ["结束", "结束对话", "结束会话", "stop", "quit"]: |
| 135 | + await conversation.finish("对话结束") |
| 136 | + |
| 137 | + chat_session: ChatSession = state["gemini_chat_session"] |
| 138 | + |
| 139 | + try: |
| 140 | + resp = await chat_session.send_message_async(msg) |
| 141 | + except Exception as e: |
| 142 | + await conversation.finish(f"发生意外错误,对话已结束\n{type(e).__name__}: {e}") |
| 143 | + |
| 144 | + try: |
| 145 | + result = resp.text |
| 146 | + except ValueError: |
| 147 | + result = "\n---\n".join( |
| 148 | + [part.text for part in resp.candidates[0].content.parts] |
| 149 | + ) |
| 150 | + |
| 151 | + await conversation.reject( |
| 152 | + await UniMessage(Image(raw=await to_markdown(result))).export() |
| 153 | + if len(result) > 500 |
| 154 | + else result.strip() |
| 155 | + ) |
0 commit comments