Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ local_path_override(
path = "src",
)

bazel_dep(name = "psi", version = "0.6.0.dev250507")
bazel_dep(name = "psi", version = "0.6.0.dev250922")
bazel_dep(name = "yacl", version = "0.4.5b10-nightly-20250110")
git_override(
module_name = "yacl",
remote = "https://github.com/secretflow/yacl.git",
commit = "a8af9f85816139f62712e29f0e5421da6f7f1e32",
)

bazel_dep(name = "grpc", version = "1.66.0.bcr.4")

# pin [email protected]
Expand Down
56 changes: 21 additions & 35 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

363 changes: 363 additions & 0 deletions examples/python/advanced/http_channel.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,363 @@
#!/usr/bin/env python3
"""
HttpChannel - 基于HTTP的IChannel实现 (P0功能)
"""

import threading
import time
import logging
import requests
from typing import Dict, Optional
from http.server import HTTPServer, BaseHTTPRequestHandler
import urllib.parse
import json

import spu.libspu.link as link
import spu.libspu as spu

# Configure logging. This runs at import time and sets up the root logger at
# INFO level with timestamp, logger name and level in every record.
logging.basicConfig(
level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
# Module-level logger used by the channel, the HTTP handler and the demo code.
logger = logging.getLogger(__name__)


class HttpChannel(link.IChannel):
    """HTTP-based implementation of ``link.IChannel``.

    Outgoing messages are POSTed to ``/send/...`` on the HTTP server listening
    on ``remote_port``; incoming messages are polled with GET from
    ``/recv/...`` on the HTTP server listening on ``local_port``. Both servers
    are addressed through ``localhost``.
    """

    def __init__(
        self,
        local_rank: int,
        remote_rank: int,
        local_port: int,
        remote_port: int,
    ):
        """Create a channel between ``local_rank`` and ``remote_rank``.

        Args:
            local_rank: Rank of the party owning this channel.
            remote_rank: Rank of the peer party.
            local_port: Port of the HTTP server that ``Recv`` polls.
            remote_port: Port of the HTTP server that ``Send`` posts to.
        """
        super().__init__()
        self.local_rank = local_rank
        self.remote_rank = remote_rank
        self.local_port = local_port
        self.remote_port = remote_port
        self.recv_timeout = 10000  # milliseconds (10 seconds by default)
        self.throttle_window_size = 1024
        self.chunk_parallel_send_size = 4

        # HTTP client
        self.session = requests.Session()
        # requests does not honour a ``timeout`` attribute set on a Session,
        # so the (connect, read) timeout in seconds is kept on the channel and
        # passed explicitly to every request.
        self.request_timeout = (1, 5)

    def _get_url(self, key: str, operation: str) -> str:
        """Build the URL for ``key``.

        ``operation == "send"`` targets the remote server; any other value is
        treated as a receive and targets the local server.
        """
        if operation == "send":
            return f"http://localhost:{self.remote_port}/send/{self.local_rank}/{self.remote_rank}/{key}"
        # recv
        return f"http://localhost:{self.local_port}/recv/{self.remote_rank}/{self.local_rank}/{key}"

    def SendAsync(self, key: str, buf: bytes) -> None:
        """POST ``buf`` under ``key`` to the remote server.

        Despite the name, the call blocks until the HTTP request completes.

        Raises:
            RuntimeError: If the request fails or returns an error status.
        """
        url = self._get_url(key, "send")
        try:
            response = self.session.post(
                url, data=buf, timeout=self.request_timeout
            )
            response.raise_for_status()
        except requests.RequestException as e:
            raise RuntimeError(
                f"网络错误: 无法发送数据到远程节点 {self.remote_rank}: {e}"
            ) from e

    def SendAsyncThrottled(self, key: str, buf: bytes) -> None:
        """Send ``buf`` under ``key``; delegates to ``SendAsync`` unthrottled."""
        self.SendAsync(key, buf)

    def Send(self, key: str, value: bytes) -> None:
        """Send ``value`` under ``key`` and log the start and end of the transfer."""
        logger.info(
            f"[HttpChannel({self.local_rank=}, {self.remote_rank=})] Start to send key: {key}, size: {len(value)} bytes"
        )
        self.SendAsync(key, value)
        logger.info(
            f"[HttpChannel({self.local_rank=}, {self.remote_rank=})] Finished to send key: {key}"
        )

    def Recv(self, key: str) -> bytes:
        """Poll the local server until the message stored under ``key`` arrives.

        Returns:
            The body of the first 200 response.

        Raises:
            RuntimeError: If a request fails, or if no message arrived within
                ``recv_timeout`` milliseconds.
        """
        url = self._get_url(key, "recv")
        # A monotonic clock keeps the deadline immune to wall-clock changes.
        start_time = time.monotonic()
        logger.info(
            f"[HttpChannel({self.local_rank=}, {self.remote_rank=})] Start to recv key: {key}"
        )

        while True:
            try:
                response = self.session.get(url, timeout=self.request_timeout)
            except requests.RequestException as e:
                raise RuntimeError(
                    f"网络错误: 无法从远程节点 {self.remote_rank} 接收数据: {e}"
                ) from e

            if response.status_code == 200:
                logger.info(
                    f"[HttpChannel({self.local_rank=}, {self.remote_rank=})] Finished to recv key: {key}"
                )
                return response.content
            if response.status_code != 404:
                # 404 means "not arrived yet"; anything else is unexpected, so
                # make it visible instead of retrying silently.
                logger.warning(
                    f"[HttpChannel({self.local_rank=}, {self.remote_rank=})] Unexpected status {response.status_code} while receiving key: {key}"
                )

            # Check the timeout before waiting for the next poll.
            if (time.monotonic() - start_time) * 1000 > self.recv_timeout:
                raise RuntimeError(
                    f"超时错误: 等待从远程节点 {self.remote_rank} 接收数据超时"
                )

            time.sleep(0.5)

    def SetRecvTimeout(self, timeout_ms: int) -> None:
        """Set the receive timeout in milliseconds."""
        self.recv_timeout = timeout_ms

    def GetRecvTimeout(self) -> int:
        """Return the receive timeout in milliseconds."""
        return self.recv_timeout

    def WaitLinkTaskFinish(self) -> None:
        """No-op in this implementation."""
        pass

    def Abort(self) -> None:
        """No-op in this implementation."""
        pass

    def SetThrottleWindowSize(self, size: int) -> None:
        """Store the throttle window size; the value is not used when sending."""
        self.throttle_window_size = size

    def TestSend(self, timeout: int) -> None:
        """Send an empty message under the key ``"test"``.

        ``timeout`` is accepted but not used by this implementation.
        """
        self.Send("test", b"")

    def TestRecv(self) -> None:
        """Receive the message stored under the key ``"test"``."""
        self.Recv("test")

    def SetChunkParallelSendSize(self, size: int) -> None:
        """Store the parallel chunk send size; the value is not used when sending."""
        self.chunk_parallel_send_size = size


class HttpChannelHandler(BaseHTTPRequestHandler):
    """Request handler for the HTTP channel server.

    ``POST /send/{from_rank}/{to_rank}/{key}`` stores the request body,
    ``GET /recv/{from_rank}/{to_rank}/{key}`` returns and removes a stored
    body, and ``GET /keys`` lists the stored keys as JSON. Entries live in the
    ``storage`` dict under ``"{key}_{from_rank}_{to_rank}"``.
    """

    def __init__(self, storage: Dict[str, bytes], *args, **kwargs):
        # ``storage`` is assigned before the base initializer is called so it
        # is already available while the request is being handled.
        self.storage = storage
        super().__init__(*args, **kwargs)

    def _final_key(self) -> Optional[str]:
        """Return the storage key encoded in ``self.path``.

        Returns ``None`` when the path has fewer than the three required
        components after the ``/send`` or ``/recv`` prefix.

        Raises:
            ValueError: If either rank component is not an integer.
        """
        parts = self.path.split('/')
        if len(parts) < 5:
            return None
        from_rank = int(parts[2])
        to_rank = int(parts[3])
        # Keep everything after the ranks so a key containing '/' is not
        # truncated to its first segment.
        key = '/'.join(parts[4:])
        return f"{key}_{from_rank}_{to_rank}"

    def _respond(self, code: int) -> None:
        """Send a body-less response with status ``code``."""
        self.send_response(code)
        self.end_headers()

    def do_POST(self):
        """Handle a send request by storing the request body."""
        if not self.path.startswith('/send/'):
            self._respond(404)
            return

        # URL layout: /send/{from_rank}/{to_rank}/{key}
        try:
            final_key = self._final_key()
        except ValueError:
            self._respond(400)
            return
        if final_key is None:
            self._respond(404)
            return

        try:
            content_length = int(self.headers.get('Content-Length', 0))
        except ValueError:
            self._respond(400)
            return
        if content_length < 0:
            # A negative length would make ``read`` wait for end of stream.
            self._respond(400)
            return

        data = self.rfile.read(content_length)
        self.storage[final_key] = data

        logger.info(f"[STORE] Saved key: {final_key}, size: {len(data)} bytes")

        self._respond(200)

    def do_GET(self):
        """Handle a key listing or a receive request."""
        if self.path == '/keys':
            # Return every stored key.
            keys = list(self.storage)
            response = json.dumps(keys).encode()
            logger.info(f"[KEYS] Returning {len(keys)} keys: {keys}")
            self.send_response(200)
            self.send_header('Content-Type', 'application/json')
            self.send_header('Content-Length', str(len(response)))
            self.end_headers()
            self.wfile.write(response)
            return

        if not self.path.startswith('/recv/'):
            self._respond(404)
            return

        # URL layout: /recv/{from_rank}/{to_rank}/{key}
        try:
            final_key = self._final_key()
        except ValueError:
            self._respond(400)
            return
        if final_key is None:
            self._respond(404)
            return

        # Remove the entry while fetching it so a delivered message does not
        # linger in storage. ``None`` marks "missing"; an empty body is valid.
        data = self.storage.pop(final_key, None)
        if data is None:
            logger.info(f"[RECV] Key not found: {final_key}")
            self._respond(404)
            return

        logger.info(
            f"[RECV] Retrieved and deleted key: {final_key}, size: {len(data)} bytes"
        )
        self.send_response(200)
        self.send_header('Content-Length', str(len(data)))
        self.end_headers()
        self.wfile.write(data)


class HttpChannelServer:
    """Runs an ``HTTPServer`` with ``HttpChannelHandler`` in a background thread.

    Usable as a context manager: the server is started on entry and stopped
    on exit.
    """

    def __init__(self, port: int, storage: Dict[str, bytes]):
        """Remember the port to listen on and the dict handed to each handler."""
        self.port = port
        self.storage = storage
        self.server = None
        self.thread = None

    def start(self):
        """Start serving on ``localhost:port`` in a daemon thread.

        Calling ``start`` while the server is already running is a no-op, so
        an existing server and thread are never overwritten and leaked.
        """
        if self.server is not None:
            return

        def handler(*args, **kwargs):
            # Inject the shared storage in front of the arguments that the
            # server passes to its request handler.
            return HttpChannelHandler(self.storage, *args, **kwargs)

        self.server = HTTPServer(('localhost', self.port), handler)

        self.thread = threading.Thread(target=self.server.serve_forever, daemon=True)
        self.thread.start()

    def stop(self):
        """Stop the server, close its socket and wait for the thread to exit.

        Safe to call when the server was never started or is already stopped.
        """
        server, thread = self.server, self.thread
        # Clear the references first so a repeated stop() does nothing and a
        # later start() creates a fresh server.
        self.server = None
        self.thread = None
        if server is None:
            return

        server.shutdown()
        server.server_close()
        if thread is not None:
            thread.join()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.stop()


def create_http_channel_context(rank: int, world_size: int, base_port: int = 8000):
    """Create a link context whose peers are reached through ``HttpChannel``.

    Party ``i`` is registered as ``party_{i}`` at ``localhost:{base_port + i}``.
    The channel list has one entry per party: ``None`` at index ``rank`` and
    an ``HttpChannel`` to the corresponding peer everywhere else.

    Args:
        rank: Rank of the calling party; must satisfy ``0 <= rank < world_size``.
        world_size: Total number of parties.
        base_port: Port assigned to rank 0; rank ``i`` uses ``base_port + i``.

    Returns:
        The value returned by ``link.create_with_channels``.

    Raises:
        ValueError: If ``rank`` is outside ``[0, world_size)``.
    """
    if not 0 <= rank < world_size:
        raise ValueError(
            f"rank must be in [0, {world_size}), got {rank}"
        )

    # Describe every party taking part in the link.
    desc = link.Desc()
    for i in range(world_size):
        desc.add_party(f"party_{i}", f"localhost:{base_port + i}")

    # One channel per peer; a party needs no channel to itself.
    channels = [
        None
        if i == rank
        else HttpChannel(
            local_rank=rank,
            remote_rank=i,
            local_port=base_port + rank,
            remote_port=base_port + i,
        )
        for i in range(world_size)
    ]

    return link.create_with_channels(desc, rank, channels)


def worker_process(rank: int, world_size: int, base_port: int):
    """Demo worker: build a context, then exchange one message over HTTP.

    Rank 0 sends ``b"hello from rank 0"`` under ``'hello_key'`` to rank 1;
    every other rank waits up to 3 seconds to receive that key from rank 0.
    Send and receive failures are logged rather than raised.

    Defined at module level (not under the ``__main__`` guard) so that
    ``multiprocessing`` can resolve it in child processes under the
    ``spawn`` and ``forkserver`` start methods as well as ``fork``.
    """
    # NOTE(review): starting an in-process HttpChannelServer on
    # ``base_port + rank`` is currently disabled here. The channels post to and
    # poll ``localhost`` HTTP endpoints, so a server must already be listening
    # on those ports for the transfers below to succeed - confirm this is
    # intended.

    # Create the context.
    logger.info(f"Rank {rank}: 创建HttpChannel Context")
    ctx = create_http_channel_context(rank, world_size, base_port)
    logger.info(f"Rank {rank}: Context创建成功")

    # Check that the context reports the expected identity.
    logger.info(f"Rank {rank}: 验证Context功能")
    logger.info(
        f"Rank {rank}: Context rank={ctx.rank}, world_size={ctx.world_size}"
    )

    # Simple round trip between rank 0 and rank 1.
    logger.info(f"Rank {rank}: HttpChannel测试开始")
    if rank == 0:
        # Rank 0: open a channel to rank 1 and send the payload.
        test_channel = HttpChannel(
            local_rank=0,
            remote_rank=1,
            local_port=base_port,
            remote_port=base_port + 1,
        )
        data = b"hello from rank 0"
        logger.info(f"Rank {rank}: 准备发送数据: '{data.decode()}'")
        try:
            test_channel.Send('hello_key', data)
            logger.info(f"Rank {rank}: 已发送 '{data.decode()}'")
        except RuntimeError as e:
            logger.error(f"Rank {rank}: 发送失败 - {e}")
        except Exception as e:
            logger.error(f"Rank {rank}: 发送过程发生其他错误 - {e}")
    else:
        # Rank 1: open a channel to rank 0 and wait for the payload.
        test_channel = HttpChannel(
            local_rank=1,
            remote_rank=0,
            local_port=base_port + 1,
            remote_port=base_port,
        )
        test_channel.SetRecvTimeout(3000)  # 3 second timeout

        logger.info(f"Rank {rank}: 等待接收数据")
        try:
            received = test_channel.Recv('hello_key')
            logger.info(f"Rank {rank}: 已接收 '{received.decode()}'")
        except RuntimeError as e:
            logger.error(f"Rank {rank}: 接收失败 - {e}")
        except Exception as e:
            logger.error(f"Rank {rank}: 接收过程发生其他错误 - {e}")


if __name__ == "__main__":
    import multiprocessing

    # Demo run: one process per rank.
    world_size = 2
    base_port = 61930

    processes = []
    for rank in range(world_size):
        p = multiprocessing.Process(
            target=worker_process, args=(rank, world_size, base_port)
        )
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    print("HttpChannel测试完成")
Loading
Loading