diff --git a/python/rtclient/low_level_client.py b/python/rtclient/low_level_client.py index f023270..8c36d36 100644 --- a/python/rtclient/low_level_client.py +++ b/python/rtclient/low_level_client.py @@ -3,15 +3,21 @@ import json import uuid +import logging +import asyncio +import random from collections.abc import AsyncIterator -from typing import Any, Optional +from typing import Any, Optional, Callable -from aiohttp import ClientSession, WSMsgType, WSServerHandshakeError +from aiohttp import ClientSession, WSMsgType, WSServerHandshakeError, ClientConnectorError from rtclient.models import ServerMessageType, UserMessageType, create_message_from_dict from rtclient.util.user_agent import get_user_agent +logger = logging.getLogger(__name__) + + class ConnectionError(Exception): def __init__(self, message: str, headers=None): super().__init__(message) @@ -24,6 +30,11 @@ def __init__( url: str, headers: Optional[dict[str, str]] = None, params: Optional[dict[str, Any]] = None, + max_retries: int = 5, + initial_retry_delay: float = 1.0, + max_retry_delay: float = 30.0, + retry_jitter: float = 0.1, + on_reconnect: Optional[Callable[[], None]] = None, ): """初始化WebSocket客户端 @@ -31,16 +42,40 @@ def __init__( url: WebSocket服务器地址 headers: 请求头 params: URL参数 + max_retries: 最大重试次数, 设置为0表示不重试, 设置为-1表示无限重试 + initial_retry_delay: 初始重试延迟(秒) + max_retry_delay: 最大重试延迟(秒) + retry_jitter: 随机波动因子(0-1之间), 用于避免重连风暴 + on_reconnect: 重连成功后的回调函数 """ self._url = url self._headers = headers or {} self._params = params or {} - self._session = ClientSession() + self._session = None self.request_id: Optional[uuid.UUID] = None self.ws = None - + + # 重连参数 + self._max_retries = max_retries + self._initial_retry_delay = initial_retry_delay + self._max_retry_delay = max_retry_delay + self._retry_jitter = retry_jitter + self._on_reconnect = on_reconnect + self._reconnecting = False + self._retry_count = 0 + self._should_reconnect = True + async def connect(self): """连接到WebSocket服务器""" + if self._session is None: + self._session = ClientSession() + + self._retry_count = 0 + self._should_reconnect = True + await self._do_connect() + + async def _do_connect(self): + """执行实际的连接逻辑""" try: self.request_id = uuid.uuid4() headers = { @@ -52,22 +87,121 @@ async def connect(self): headers=headers, params=self._params ) - except WSServerHandshakeError as e: - await self._session.close() - error_message = f"连接服务器失败,状态码: {e.status}" - raise ConnectionError(error_message, e.headers) from e + logger.info("WebSocket连接成功") + self._retry_count = 0 # 连接成功后重置重试计数 + self._reconnecting = False + + # 如果这是重连成功,则调用回调 + if self._retry_count > 0 and self._on_reconnect: + self._on_reconnect() + + except (WSServerHandshakeError, ClientConnectorError) as e: + error_type = "握手" if isinstance(e, WSServerHandshakeError) else "连接" + status = getattr(e, 'status', 'unknown') + error_message = f"WebSocket{error_type}失败,状态码: {status}" + logger.error(error_message) + + if not await self._handle_connection_failure(e): + if self._session: + await self._session.close() + headers = getattr(e, 'headers', None) + raise ConnectionError(error_message, headers) from e + async def _handle_connection_failure(self, exception) -> bool: + """处理连接失败的情况,尝试重连 + + Returns: + bool: 如果将继续重连则返回True,否则返回False + """ + # 检查是否需要重连 + if self._max_retries == 0 or (self._max_retries > 0 and self._retry_count >= self._max_retries): + logger.warning(f"达到最大重试次数 {self._max_retries},停止重连") + return False + + if not self._should_reconnect: + logger.info("重连已被禁用,不再尝试重连") + return False + + if self._reconnecting: + logger.debug("已经在重连中,跳过重连请求") + return True + + self._reconnecting = True + self._retry_count += 1 + + # 计算指数退避延迟时间(带随机抖动) + delay = min(self._initial_retry_delay * (2 ** (self._retry_count - 1)), self._max_retry_delay) + jitter = random.uniform(-self._retry_jitter, self._retry_jitter) + adjusted_delay = max(0.1, delay * (1 + jitter)) + + logger.info(f"尝试第 {self._retry_count} 次重连,等待 {adjusted_delay:.2f} 秒") + + # 异步等待后重连 + try: + await asyncio.sleep(adjusted_delay) + await self._do_connect() + return True + except Exception as e: + logger.error(f"重连过程中发生错误: {e}") + self._reconnecting = False + return True # 继续让上层处理重连逻辑 + + async def reconnect(self) -> bool: + """手动触发重连 + + Returns: + bool: 重连是否成功 + """ + if self._reconnecting: + logger.warning("已在重连过程中,忽略重连请求") + return False + + logger.info("手动触发重连") + self._reconnecting = True + + try: + # 确保旧连接已关闭 + if self.ws and not self.ws.closed: + await self.ws.close() + + await self._do_connect() + return not self.closed + except Exception as e: + logger.error(f"手动重连失败: {e}") + return False + finally: + self._reconnecting = False + async def send(self, message: UserMessageType | dict[str, Any]): """发送消息到服务器 Args: message: 要发送的消息,可以是 UserMessageType 或 dict """ - if hasattr(message, 'model_dump_json'): - message_data = message.model_dump_json() - else: - message_data = json.dumps(message) - await self.ws.send_str(message_data) + if self.ws is None or self.ws.closed: + logger.error("WebSocket连接已关闭,无法发送消息") + await self.reconnect() + if self.ws is None or self.ws.closed: + raise ConnectionError("WebSocket连接已关闭且重连失败,无法发送消息") + + try: + if hasattr(message, 'model_dump_json'): + message_data = message.model_dump_json() + else: + message_data = json.dumps(message) + await self.ws.send_str(message_data) + except Exception as e: + logger.error(f"发送消息失败: {e}") + # 尝试重连并重新发送 + if await self.reconnect(): + # 重连成功,重试发送 + if hasattr(message, 'model_dump_json'): + message_data = message.model_dump_json() + else: + message_data = json.dumps(message) + await self.ws.send_str(message_data) + else: + raise async def send_json(self, message: dict[str, Any]): """发送JSON消息到服务器 @@ -75,7 +209,21 @@ async def send_json(self, message: dict[str, Any]): Args: message: 要发送的JSON消息 """ - await self.ws.send_json(message) + if self.ws is None or self.ws.closed: + logger.error("WebSocket连接已关闭,无法发送JSON消息") + await self.reconnect() + if self.ws is None or self.ws.closed: + raise ConnectionError("WebSocket连接已关闭且重连失败,无法发送JSON消息") + + try: + await self.ws.send_json(message) + except Exception as e: + logger.error(f"发送JSON消息失败: {e}") + # 尝试重连并重新发送 + if await self.reconnect(): + await self.ws.send_json(message) + else: + raise async def recv(self) -> Optional[ServerMessageType]: """接收服务器消息 @@ -83,14 +231,32 @@ async def recv(self) -> Optional[ServerMessageType]: Returns: 接收到的消息对象 """ - if self.ws.closed: - return None - websocket_message = await self.ws.receive() - if websocket_message.type == WSMsgType.TEXT: - data = json.loads(websocket_message.data) - msg = create_message_from_dict(data) - return msg - else: + if self.ws is None or self.ws.closed: + logger.error("WebSocket连接已关闭,无法接收消息") + await self.reconnect() + if self.ws is None or self.ws.closed: + return None + + try: + websocket_message = await self.ws.receive() + + if websocket_message.type == WSMsgType.TEXT: + data = json.loads(websocket_message.data) + msg = create_message_from_dict(data) + return msg + elif websocket_message.type == WSMsgType.CLOSED: + logger.warning("服务器关闭了WebSocket连接") + await self.reconnect() + return None + elif websocket_message.type == WSMsgType.ERROR: + logger.error(f"WebSocket连接错误: {websocket_message.data}") + await self.reconnect() + return None + else: + return None + except Exception as e: + logger.error(f"接收消息时发生错误: {e}") + await self.reconnect() return None def __aiter__(self) -> AsyncIterator[ServerMessageType]: @@ -104,9 +270,20 @@ async def __anext__(self): async def close(self): """关闭连接""" + self._should_reconnect = False # 禁用重连 + if self.ws: - await self.ws.close() - await self._session.close() + try: + await self.ws.close() + except Exception as e: + logger.warning(f"关闭WebSocket连接时发生错误: {e}") + + if self._session: + try: + await self._session.close() + self._session = None + except Exception as e: + logger.warning(f"关闭HTTP会话时发生错误: {e}") @property def closed(self) -> bool: diff --git a/python/rtclient/util/user_agent.py b/python/rtclient/util/user_agent.py index 7854e27..ddfba3e 100644 --- a/python/rtclient/util/user_agent.py +++ b/python/rtclient/util/user_agent.py @@ -2,10 +2,17 @@ # Licensed under the MIT license. import platform -from importlib.metadata import version +from importlib.metadata import version, PackageNotFoundError def get_user_agent(): - package_version = version("rtclient") + try: + package_version = version("rtclient") + except PackageNotFoundError: + package_version = "dev" + python_version = platform.python_version() - return f"zhipu-rtclient/{package_version} Python/{python_version}" + system = platform.system() + architecture = platform.machine() + + return f"zhipu-rtclient/{package_version} Python/{python_version} {system}/{architecture}" diff --git a/python/samples/low_level_sample_with_reconnect.py b/python/samples/low_level_sample_with_reconnect.py new file mode 100644 index 0000000..4cef448 --- /dev/null +++ b/python/samples/low_level_sample_with_reconnect.py @@ -0,0 +1,275 @@ +# Copyright (c) ZhiPu Corporation. +# Licensed under the MIT license. + +import asyncio +import base64 +import os +import signal +import sys +import wave +import logging +import random +from io import BytesIO +from typing import Optional + +from dotenv import load_dotenv +from message_handler import create_message_handler + +from rtclient import RTLowLevelClient +from rtclient.models import ( + ClientVAD, + InputAudioBufferAppendMessage, + InputAudioBufferCommitMessage, + SessionUpdateMessage, + SessionUpdateParams, +) + +# 设置日志 +logging.basicConfig( + level=logging.INFO, + format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger("reconnect_demo") + +shutdown_event: Optional[asyncio.Event] = None + + +def handle_shutdown(sig=None, frame=None): + """处理关闭信号""" + if shutdown_event: + logger.info("正在关闭程序...") + shutdown_event.set() + + +def encode_wave_to_base64(wave_file_path): + """ + 将WAV文件转换为base64编码,确保生成标准的WAV格式 + Args: + wave_file_path: WAV文件路径 + Returns: + base64编码的字符串 + """ + try: + with wave.open(wave_file_path, "rb") as wave_file: + # 获取音频参数 + channels = wave_file.getnchannels() + sample_width = wave_file.getsampwidth() + frame_rate = wave_file.getframerate() + frames = wave_file.readframes(wave_file.getnframes()) + + # 验证音频参数是否合法 + if channels < 1 or sample_width < 1 or frame_rate <= 0: + logger.error(f"无效的音频参数: channels={channels}, sample_width={sample_width}, frame_rate={frame_rate}") + return None + + # 创建字节流并写入标准WAV格式 + wave_io = BytesIO() + with wave.open(wave_io, "wb") as wave_out: + # 设置WAV文件头部信息 + wave_out.setnchannels(channels) + wave_out.setsampwidth(sample_width) # 位深度 (1 = 8位, 2 = 16位, etc.) + wave_out.setframerate(frame_rate) # 采样率 (常见值: 44100, 48000) + # 写入音频数据 + wave_out.writeframes(frames) + + # 确保写入完整的WAV文件数据 + wave_io.seek(0) + + # 获取字节数据并编码为base64 + logger.info(f"音频参数: 声道数={channels}, 位深度={sample_width*8}位, 采样率={frame_rate}Hz") + return base64.b64encode(wave_io.getvalue()).decode("utf-8") + except Exception as e: + logger.error(f"音频文件处理错误: {str(e)}") + return None + + +async def send_audio(client: RTLowLevelClient, audio_file_path: str): + """发送音频""" + base64_content = encode_wave_to_base64(audio_file_path) + if base64_content is None: + logger.error("音频编码失败") + return + + # 验证音频数据长度 + if len(base64_content) == 0: + logger.error("音频数据为空") + return + + # 发送音频数据 + audio_message = InputAudioBufferAppendMessage( + audio=base64_content, client_timestamp=int(asyncio.get_event_loop().time() * 1000) + ) + await client.send(audio_message) + + +def get_env_var(var_name: str) -> str: + value = os.environ.get(var_name) + if not value: + raise OSError(f"环境变量 '{var_name}' 未设置或为空。") + return value + + +class ConnectionBreaker: + """模拟网络连接中断的工具类""" + def __init__(self, client: RTLowLevelClient): + self.client = client + self.break_task = None + self.enabled = False + + def start(self, min_interval=10, max_interval=30): + """开始模拟随机断开连接""" + self.enabled = True + self.break_task = asyncio.create_task(self._run_breaker(min_interval, max_interval)) + + def stop(self): + """停止模拟随机断开连接""" + self.enabled = False + if self.break_task: + self.break_task.cancel() + + async def _run_breaker(self, min_interval, max_interval): + """执行随机断开连接的循环""" + while self.enabled: + # 随机等待一段时间 + interval = random.uniform(min_interval, max_interval) + logger.info(f"计划在 {interval:.2f} 秒后模拟连接中断") + await asyncio.sleep(interval) + + if not self.enabled: + break + + # 模拟连接中断 + logger.warning("模拟连接中断...") + if hasattr(self.client, 'ws') and self.client.ws: + try: + # 强制关闭连接但不触发正常关闭流程 + await self.client.ws.close(code=1006, message=b"Simulated network interruption") + except Exception as e: + logger.error(f"模拟中断时出错: {e}") + + +async def on_reconnect(): + """重连成功的回调函数""" + logger.info("重连成功,重新初始化会话状态...") + + +async def with_zhipu(audio_file_path: str, simulate_breaks=True): + global shutdown_event + shutdown_event = asyncio.Event() + + for sig in (signal.SIGINT, signal.SIGTERM): + signal.signal(sig, handle_shutdown) + + api_key = get_env_var("ZHIPU_API_KEY") + try: + # 使用新的重连参数创建客户端 + client = RTLowLevelClient( + url="wss://open.bigmodel.cn/api/paas/v4/realtime", + headers={"Authorization": f"Bearer {api_key}"}, + max_retries=5, # 最多重试5次 + initial_retry_delay=1.0, # 初始1秒延迟 + max_retry_delay=15.0, # 最大15秒延迟 + retry_jitter=0.2, # 20%的随机抖动 + on_reconnect=on_reconnect # 重连成功后的回调 + ) + + # 创建连接断开模拟器 + connection_breaker = ConnectionBreaker(client) + + try: + await client.connect() + + # 如果需要,启动连接断开模拟器 + if simulate_breaks: + connection_breaker.start(min_interval=5, max_interval=15) # 5-15秒随机断开 + logger.info("已启动连接断开模拟器") + + # 发送会话配置 + if shutdown_event.is_set(): + return + + session_message = SessionUpdateMessage( + session=SessionUpdateParams( + input_audio_format="wav", + output_audio_format="pcm", + modalities={"audio", "text"}, + turn_detection=ClientVAD(), + beta_fields={"chat_mode": "audio", "tts_source": "e2e", "auto_search": False}, + tools=[], + ) + ) + await client.send(session_message) + + if shutdown_event.is_set(): + return + + # 创建消息处理器 + message_handler = await create_message_handler(client, shutdown_event) + + async def send_audio_with_commit(): + # 发送音频数据 + await send_audio(client, audio_file_path) + # 提交音频缓冲区 + commit_message = InputAudioBufferCommitMessage( + client_timestamp=int(asyncio.get_event_loop().time() * 1000) + ) + await client.send(commit_message) + # 发送创建响应的消息 + await client.send_json({"type": "response.create"}) + + # 创建发送和接收任务 + send_task = asyncio.create_task(send_audio_with_commit()) + receive_task = asyncio.create_task(message_handler.receive_messages()) + + # 等待任务完成 + try: + await asyncio.gather(send_task, receive_task) + except Exception as e: + logger.error(f"任务执行出错: {e}") + # 取消未完成的任务 + for task in [send_task, receive_task]: + if not task.done(): + task.cancel() + try: + await task + except asyncio.CancelledError: + pass + finally: + # 停止连接断开模拟器 + if simulate_breaks: + connection_breaker.stop() + logger.info("已停止连接断开模拟器") + # 关闭客户端连接 + await client.close() + except Exception as e: + logger.error(f"发生错误: {e}") + finally: + if shutdown_event.is_set(): + logger.info("程序已完成退出") + + +if __name__ == "__main__": + load_dotenv() + if len(sys.argv) < 2: + print("使用方法: python low_level_sample_with_reconnect.py <音频文件> [--no-simulate]") + sys.exit(1) + + file_path = sys.argv[1] + if not os.path.exists(file_path): + print(f"音频文件 {file_path} 不存在") + sys.exit(1) + + # 检查是否禁用断开模拟 + simulate_breaks = True + if len(sys.argv) > 2 and sys.argv[2] == "--no-simulate": + simulate_breaks = False + print("断开连接模拟已禁用") + + try: + asyncio.run(with_zhipu(file_path, simulate_breaks)) + except KeyboardInterrupt: + print("\n程序被用户中断") + except Exception as e: + print(f"程序执行出错: {e}") + finally: + print("程序已退出") \ No newline at end of file diff --git a/python/tests/test_low_level_client_reconnect.py b/python/tests/test_low_level_client_reconnect.py new file mode 100644 index 0000000..34dc4b5 --- /dev/null +++ b/python/tests/test_low_level_client_reconnect.py @@ -0,0 +1,179 @@ +# Copyright (c) ZhiPu Corporation. +# Licensed under the MIT License. + +import unittest +import asyncio +import logging +import aiohttp +from unittest.mock import AsyncMock, MagicMock, patch +from aiohttp import ClientSession, WSMsgType, WSServerHandshakeError, ClientConnectorError + +from rtclient.low_level_client import RTLowLevelClient, ConnectionError + + +class TestRTLowLevelClientReconnect(unittest.TestCase): + def setUp(self): + # 设置测试用的日志级别 + logging.basicConfig(level=logging.DEBUG) + # 创建事件循环 + self.loop = asyncio.new_event_loop() + asyncio.set_event_loop(self.loop) + + def tearDown(self): + # 关闭事件循环 + self.loop.close() + + @patch('rtclient.low_level_client.ClientSession') + async def async_test_connect_success(self, mock_session): + # 配置模拟对象 + mock_ws = AsyncMock() + mock_ws.closed = False + mock_session_instance = AsyncMock() + mock_session_instance.ws_connect = AsyncMock(return_value=mock_ws) + mock_session.return_value = mock_session_instance + + # 创建客户端并连接 + client = RTLowLevelClient("wss://example.com/ws") + await client.connect() + + # 验证调用了正确的方法 + mock_session.assert_called_once() + mock_session_instance.ws_connect.assert_called_once() + self.assertFalse(client.closed) + + # 清理 + await client.close() + + @patch('rtclient.low_level_client.ClientSession') + async def async_test_connect_failure_no_retry(self, mock_session): + # 配置模拟对象以抛出异常 + mock_session_instance = AsyncMock() + error = WSServerHandshakeError(request_info=MagicMock(), history=MagicMock(), status=403) + mock_session_instance.ws_connect = AsyncMock(side_effect=error) + mock_session.return_value = mock_session_instance + + # 创建客户端,设置不重试 + client = RTLowLevelClient("wss://example.com/ws", max_retries=0) + + # 尝试连接应该失败 + with self.assertRaises(ConnectionError): + await client.connect() + + # 验证调用情况 + mock_session.assert_called_once() + mock_session_instance.ws_connect.assert_called_once() + self.assertTrue(client.closed) + + @patch('rtclient.low_level_client.ClientSession') + @patch('rtclient.low_level_client.asyncio.sleep', new_callable=AsyncMock) + async def async_test_reconnect_success_after_failure(self, mock_sleep, mock_session): + # 配置模拟对象,第一次连接失败,第二次成功 + mock_ws = AsyncMock() + mock_ws.closed = False + mock_session_instance = AsyncMock() + + # 第一次调用抛出异常,第二次返回成功 + error = WSServerHandshakeError(request_info=MagicMock(), history=MagicMock(), status=500) + mock_session_instance.ws_connect = AsyncMock(side_effect=[error, mock_ws]) + mock_session.return_value = mock_session_instance + + # 创建客户端,设置重试一次 + client = RTLowLevelClient("wss://example.com/ws", max_retries=1, initial_retry_delay=0.1) + + # 连接应该最终成功 + await client.connect() + + # 验证调用情况 + mock_session.assert_called_once() + self.assertEqual(mock_session_instance.ws_connect.call_count, 2) + mock_sleep.assert_called_once() + self.assertFalse(client.closed) + + # 清理 + await client.close() + + @patch('rtclient.low_level_client.ClientSession') + async def async_test_send_with_reconnect(self, mock_session): + # 配置模拟对象 + mock_ws = AsyncMock() + mock_ws.closed = False + mock_ws.send_str = AsyncMock() + + mock_session_instance = AsyncMock() + mock_session_instance.ws_connect = AsyncMock(return_value=mock_ws) + mock_session.return_value = mock_session_instance + + # 创建客户端并连接 + client = RTLowLevelClient("wss://example.com/ws") + await client.connect() + + # 模拟连接断开 + mock_ws.closed = True + + # 设置重连方法 + original_reconnect = client.reconnect + client.reconnect = AsyncMock(side_effect=lambda: setattr(mock_ws, 'closed', False) or original_reconnect.__call__()) + + # 尝试发送消息 + await client.send({"test": "message"}) + + # 验证重连被调用 + client.reconnect.assert_called_once() + # 验证消息被发送 + mock_ws.send_str.assert_called_once() + + # 清理 + await client.close() + + @patch('rtclient.low_level_client.ClientSession') + async def async_test_recv_reconnect_on_error(self, mock_session): + # 配置模拟对象 + mock_ws = AsyncMock() + mock_ws.closed = False + + # 配置receive方法抛出异常 + error_msg = MagicMock() + error_msg.type = WSMsgType.ERROR + error_msg.data = "Connection error" + mock_ws.receive = AsyncMock(return_value=error_msg) + + mock_session_instance = AsyncMock() + mock_session_instance.ws_connect = AsyncMock(return_value=mock_ws) + mock_session.return_value = mock_session_instance + + # 创建客户端并连接 + client = RTLowLevelClient("wss://example.com/ws") + await client.connect() + + # 设置重连方法 + client.reconnect = AsyncMock(return_value=True) + + # 接收消息 + result = await client.recv() + + # 验证重连被调用 + client.reconnect.assert_called_once() + # 返回值应为None + self.assertIsNone(result) + + # 清理 + await client.close() + + def test_connect_success(self): + self.loop.run_until_complete(self.async_test_connect_success()) + + def test_connect_failure_no_retry(self): + self.loop.run_until_complete(self.async_test_connect_failure_no_retry()) + + def test_reconnect_success_after_failure(self): + self.loop.run_until_complete(self.async_test_reconnect_success_after_failure()) + + def test_send_with_reconnect(self): + self.loop.run_until_complete(self.async_test_send_with_reconnect()) + + def test_recv_reconnect_on_error(self): + self.loop.run_until_complete(self.async_test_recv_reconnect_on_error()) + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file