diff --git a/MODULE.bazel b/MODULE.bazel index 522b8a7a..8f20d0fd 100644 --- a/MODULE.bazel +++ b/MODULE.bazel @@ -36,8 +36,14 @@ local_path_override( path = "src", ) -bazel_dep(name = "psi", version = "0.6.0.dev250507") +bazel_dep(name = "psi", version = "0.6.0.dev250922") bazel_dep(name = "yacl", version = "0.4.5b10-nightly-20250110") +git_override( + module_name = "yacl", + remote = "https://github.com/secretflow/yacl.git", + commit = "a8af9f85816139f62712e29f0e5421da6f7f1e32", +) + bazel_dep(name = "grpc", version = "1.66.0.bcr.4") # pin grpc@1.66.0.bcr.4 diff --git a/MODULE.bazel.lock b/MODULE.bazel.lock index 41b2360e..c05ae3d9 100644 --- a/MODULE.bazel.lock +++ b/MODULE.bazel.lock @@ -37,8 +37,9 @@ "https://bcr.bazel.build/modules/bazel_features/1.19.0/MODULE.bazel": "59adcdf28230d220f0067b1f435b8537dd033bfff8db21335ef9217919c7fb58", "https://bcr.bazel.build/modules/bazel_features/1.20.0/MODULE.bazel": "8b85300b9c8594752e0721a37210e34879d23adc219ed9dc8f4104a4a1750920", "https://bcr.bazel.build/modules/bazel_features/1.21.0/MODULE.bazel": "675642261665d8eea09989aa3b8afb5c37627f1be178382c320d1b46afba5e3b", - "https://bcr.bazel.build/modules/bazel_features/1.21.0/source.json": "3e8379efaaef53ce35b7b8ba419df829315a880cb0a030e5bb45c96d6d5ecb5f", "https://bcr.bazel.build/modules/bazel_features/1.3.0/MODULE.bazel": "cdcafe83ec318cda34e02948e81d790aab8df7a929cec6f6969f13a489ccecd9", + "https://bcr.bazel.build/modules/bazel_features/1.35.0/MODULE.bazel": "3d9393e5317df8afcfc509458591874ea734fa68ecbdd64fbd6c2c0cbe399526", + "https://bcr.bazel.build/modules/bazel_features/1.35.0/source.json": "c61e98cb3573ce0b8d69eb77c652ab10545375e387e45005e7f8e84792472b09", "https://bcr.bazel.build/modules/bazel_features/1.4.1/MODULE.bazel": "e45b6bb2350aff3e442ae1111c555e27eac1d915e77775f6fdc4b351b758b5d7", "https://bcr.bazel.build/modules/bazel_features/1.9.0/MODULE.bazel": "885151d58d90d8d9c811eb75e3288c11f850e1d6b481a8c9f766adee4712358b", "https://bcr.bazel.build/modules/bazel_features/1.9.1/MODULE.bazel": "8f679097876a9b609ad1f60249c49d68bfab783dd9be012faf9d82547b14815a", @@ -262,13 +263,14 @@ "https://bcr.bazel.build/modules/opentracing-cpp/1.6.0/source.json": "da1cb1add160f5e5074b7272e9db6fd8f1b3336c15032cd0a653af9d2f484aed", "https://bcr.bazel.build/modules/platforms/0.0.10/MODULE.bazel": "8cb8efaf200bdeb2150d93e162c40f388529a25852b332cec879373771e48ed5", "https://bcr.bazel.build/modules/platforms/0.0.11/MODULE.bazel": "0daefc49732e227caa8bfa834d65dc52e8cc18a2faf80df25e8caea151a9413f", - "https://bcr.bazel.build/modules/platforms/0.0.11/source.json": "f7e188b79ebedebfe75e9e1d098b8845226c7992b307e28e1496f23112e8fc29", "https://bcr.bazel.build/modules/platforms/0.0.4/MODULE.bazel": "9b328e31ee156f53f3c416a64f8491f7eb731742655a47c9eec4703a71644aee", "https://bcr.bazel.build/modules/platforms/0.0.5/MODULE.bazel": "5733b54ea419d5eaf7997054bb55f6a1d0b5ff8aedf0176fef9eea44f3acda37", "https://bcr.bazel.build/modules/platforms/0.0.6/MODULE.bazel": "ad6eeef431dc52aefd2d77ed20a4b353f8ebf0f4ecdd26a807d2da5aa8cd0615", "https://bcr.bazel.build/modules/platforms/0.0.7/MODULE.bazel": "72fd4a0ede9ee5c021f6a8dd92b503e089f46c227ba2813ff183b71616034814", "https://bcr.bazel.build/modules/platforms/0.0.8/MODULE.bazel": "9f142c03e348f6d263719f5074b21ef3adf0b139ee4c5133e2aa35664da9eb2d", "https://bcr.bazel.build/modules/platforms/0.0.9/MODULE.bazel": "4a87a60c927b56ddd67db50c89acaa62f4ce2a1d2149ccb63ffd871d5ce29ebc", + "https://bcr.bazel.build/modules/platforms/1.0.0/MODULE.bazel": "f05feb42b48f1b3c225e4ccf351f367be0371411a803198ec34a389fb22aa580", + "https://bcr.bazel.build/modules/platforms/1.0.0/source.json": "f4ff1fd412e0246fd38c82328eb209130ead81d62dcd5a9e40910f867f733d96", "https://bcr.bazel.build/modules/prometheus-cpp/1.2.4/MODULE.bazel": "0fbe5dcff66311947a3f6b86ebc6a6d9328e31a28413ca864debc4a043f371e5", "https://bcr.bazel.build/modules/prometheus-cpp/1.2.4/source.json": "aa58bb10d0bb0dcaf4ad2c509ddcec23d2e94c3935e21517a5adbc2363248a55", "https://bcr.bazel.build/modules/protobuf/27.3/MODULE.bazel": "d94898cbf9d6d25c0edca2521211413506b68a109a6b01776832ed25154d23d7", @@ -452,6 +454,7 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.20.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.21.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.3.0/MODULE.bazel": "not found", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.35.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.4.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.9.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.9.1/MODULE.bazel": "not found", @@ -604,11 +607,10 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/gsl/4.0.0/source.json": "692503164338d148de46f5812b80bbb9b2ae719e683b0c889ece13f038525cbf", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/hash_drbg/0.0.0-20230516-2411fa9/MODULE.bazel": "12ca3c056d6d524b1c68e495e1d78dd23a07a1537f88e06e62becf1f4beb0e3f", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/hash_drbg/0.0.0-20230516-2411fa9/source.json": "6b69146300eb6ae35f1364d9f558c6c0503e70b399523e7b8fa8a6e831304151", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/heu/0.6.0.dev20250123/MODULE.bazel": "ce9d18ff97cef3f31bffa08af02118b98c815e647ec390fe715af87eaa9e7253", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/heu/0.6.0.dev20250123/source.json": "2692b9d8ec2ed2361b00c1f32bb54ea8dacae6e10ab0260a7d14cd0859c0c255", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/heu/0.6.0b0/MODULE.bazel": "119b50121b05034fb36a26b55d1f82b0c1509259b2f017d5e5a48c843545eb51", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/heu/0.6.0b0/source.json": "c11bccc312bce75c44237699a9adbaa2814e92dfb03998d624eee445f982f6a7", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/ippcp/2021.8.bcr.2/MODULE.bazel": "34e4bf82778258a13393ecf7828fee728124c3d4c07941fed56f5fcafe73f04d", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/ippcp/2021.8.bcr.2/source.json": "4a4a758ddf4292e9af6f7c6b1cd014758982a54c2dd3c50fb1912ceec68a5b69", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/ippcp/2021.8/MODULE.bazel": "b20ea49a112714f9fbfb5d2df60ca64418b9d736fa02c4d6ba36efff2108a117", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/jsoncpp/1.9.5/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/kuscia/0.14.0b0/MODULE.bazel": "ca497d1769276104afd4df13c080db068a790e0e073a13ff5564520ea695f312", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/kuscia/0.14.0b0/source.json": "03e55f9cfeaef0339a5c5f0bda7ae1cbccad98203a9cfc28023af16d0db1b719", @@ -618,6 +620,8 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/lib25519/20240321/source.json": "17672e2c227edd4b9a17fcc1d078fd94ae86ad1f0c78fc491880bf03565045ed", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/libdivide/5.0/MODULE.bazel": "2936817667b364892e0ef46c906de17c4bd967838a11181a1e036769c75c9fd7", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/libdivide/5.0/source.json": "41566db4fdc0cdfa10b74ed7f6c9b5fa99f504fecfc63c95b87f0edc673cc8ef", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/liboqs/0.13.0/MODULE.bazel": "a783966a6e3a205cde9e3dd7cd386cf8c60a4d57f346d1cbdeee8964da2fdd45", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/liboqs/0.13.0/source.json": "f98ef255946fe020308176f1fa43455e7290475ae741804c5ab57484621f18a0", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/libpfm/4.11.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/libsodium/1.0.18/MODULE.bazel": "0c5efe7944f6cf929c6b6414b539ceec47305cfc69856b1c7c9dc066016d50c5", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/libsodium/1.0.18/source.json": "85abb5d41e1f38b3909f24a92eca547d337bee936649d39461037b0582ec6cd1", @@ -645,8 +649,10 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/orc/1.9.3/source.json": "c90eaffa01ecac7c1c2e87247d449c9eb03064218476226d0ad1dfcfdf199787", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/org_interconnection/0.0.1/MODULE.bazel": "6ae177de72be5a9b49251832c2cea0ca070f11c2da3422c6bffce6263a5c78b4", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/org_interconnection/0.0.1/source.json": "2fe3f4126c5b82d21b5b1ff77d6bd8dc57cc3fb305a80d6d4fee5e5c8932d720", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pailliercryptolib/2.0.0-20231102-fdc2135/MODULE.bazel": "1ddfd40690e4955de2faf5a19093c4634ffe95cea4e618da80cba8fd6284aa29", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pailliercryptolib/2.0.0-20231102-fdc2135/source.json": "25ad090a3cc56c8b1df8fa968c01ac98a2e2b3905062aa6b1f6d49ae9ce0975c", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pailliercryptolib/2.0.0-20231102-fdc2135.bcr.1/MODULE.bazel": "1cea995d1978ba6de6c079ea466b49d8772c8b904370050c914104471218c801", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pailliercryptolib/2.0.0-20231102-fdc2135.bcr.1/source.json": "922aa56a43cb19364136afad0423c5c3a090d93cef61680426a5426f6b361a14", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/perfetto/41.0/MODULE.bazel": "c86f3f4f13e9dacdf129bc9df85d0b5f60e2e26c640185b950c0bcfd16fb3989", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/perfetto/41.0/source.json": "6534b3ed397e2c35eb9f5bf7eb553d8b78bafa66432595faa026fb09b0f77187", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.10/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.11/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.4/MODULE.bazel": "not found", @@ -655,10 +661,11 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.7/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.8/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.9/MODULE.bazel": "not found", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/1.0.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/prometheus-cpp/1.2.4/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/protobuf/27.3/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/psi/0.6.0.dev250507/MODULE.bazel": "ec404606c7a7f4d775eee6688e43d81c79589e1f575562f2d39f989f76bec613", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/psi/0.6.0.dev250507/source.json": "9827e79acd975fb6d43b2b5e7ec4f816645a81cb44a02f3377ec92e4d8c86e2d", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/psi/0.6.0.dev250922/MODULE.bazel": "a1119d36ddc23c7ce04c8bab685b33cd45b0eebd6ce1848dceabc9a686457973", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/psi/0.6.0.dev250922/source.json": "a0c75770358c252aefd0593a6e359f7289e235575e1708f373d9c92c2c7bdb4c", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pybind11_bazel/2.11.1.bzl.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pybind11_bazel/2.11.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pybind11_bazel/2.12.0/MODULE.bazel": "not found", @@ -780,9 +787,6 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/xla/20240814.0-64bdcc5/MODULE.bazel": "5348a789c31fdf43215efa633e471bf2851ff5c5044ac2a7fd04da3a3746b72d", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/xla/20240814.0-64bdcc5/source.json": "2045ff3085f2e0ea4e391d7bde507fa0689256de98bc748370e5a5c301c1cba3", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/xsimd/8.1.0/MODULE.bazel": "not found", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/yacl/0.4.5b10-nightly-20241224/MODULE.bazel": "b8503c783401674db7cfb194e14b96cac57973dc596ce4bd60c90ceaaf939c70", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/yacl/0.4.5b10-nightly-20250110/MODULE.bazel": "984e37b3d6d982edb1430f6dc100607569f78378bc3b662ac03802671d0b7c7a", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/yacl/0.4.5b10-nightly-20250110/source.json": "8ca56ae0ba48aefadf631c303ef3e0c3bb3cc2f6e0e717264a32b633c5dc4b05", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/zlib/1.2.11/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/zlib/1.2.13/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/zlib/1.3.1.bcr.1/MODULE.bazel": "not found", @@ -1157,8 +1161,8 @@ }, "@@psi~//bazel:defs.bzl%non_module_dependencies": { "general": { - "bzlTransitiveDigest": "Lof+AoGR1XS+eNcPLWJjUy4ZzoavywKWa/MBlWLYpK8=", - "usagesDigest": "lXD3AjgUsQBfotJ7azSny4UVe+Cj15ywb24Kj21VmhE=", + "bzlTransitiveDigest": "NAC0+n0skvguFGiSWQSZMhbrQQVV+nNsrNIVDde/xQE=", + "usagesDigest": "i++UtNfvCLwMPos+69X3VTHMWPIot+MVMCyxNJf6O24=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, @@ -1215,24 +1219,6 @@ "build_file": "@@psi~//bazel:flatbuffers.BUILD" } }, - "perfetto": { - "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", - "ruleClassName": "http_archive", - "attributes": { - "urls": [ - "https://github.com/google/perfetto/archive/refs/tags/v41.0.tar.gz" - ], - "sha256": "4c8fe8a609fcc77ca653ec85f387ab6c3a048fcd8df9275a1aa8087984b89db8", - "strip_prefix": "perfetto-41.0", - "patch_args": [ - "-p1" - ], - "patches": [ - "@@psi~//bazel/patches:perfetto.patch" - ], - "build_file": "@@psi~//bazel:perfetto.BUILD" - } - }, "curve25519-donna": { "bzlFile": "@@bazel_tools//tools/build_defs/repo:http.bzl", "ruleClassName": "http_archive", @@ -1351,7 +1337,7 @@ "@@rules_cuda~//cuda:extensions.bzl%toolchain": { "general": { "bzlTransitiveDigest": "BnYM6/SSxkN/7InBOUBKIuviq2l1hk6LC7EDB59vI80=", - "usagesDigest": "/PYgoaRSNStmRa+ktMi63gZXkhZAcbuHjIPOuRZF8eM=", + "usagesDigest": "dR2MXEjxbM09d2eMJ/CR3hyqAlqCG/1ZdJrCmfp62bY=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, @@ -1982,7 +1968,7 @@ }, "@@rules_rust~//rust:extensions.bzl%rust": { "general": { - "bzlTransitiveDigest": "gBMm01ORsD1DBDkqspfmslCPlk4Td2Es4y9IGEy3T3w=", + "bzlTransitiveDigest": "68oGaxymy7u46xLM6Q980k5SyU14DzQUsfgpnJhXtw4=", "usagesDigest": "rV+PuweiVabXYal8kLLNKaL4Tx6LbDJmoCxxMI6XmeA=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, @@ -5071,7 +5057,7 @@ "@@spulib~//bazel:defs.bzl%non_module_dependencies": { "general": { "bzlTransitiveDigest": "JT8ZLEUdrYXN19gijrHtztFq/cEAhJlRlNjhtQUlDIE=", - "usagesDigest": "1vLMMaI5WnBmZ0/J61cSbPkEvHvnDS9gZsDaw4IcxWo=", + "usagesDigest": "bCnIBcq1eSgmVRQ9T7XkKrmvSugldUtdEYYJqvh/RIY=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {}, diff --git a/examples/python/advanced/http_channel.py b/examples/python/advanced/http_channel.py new file mode 100644 index 00000000..17ef01ca --- /dev/null +++ b/examples/python/advanced/http_channel.py @@ -0,0 +1,363 @@ +#!/usr/bin/env python3 +""" +HttpChannel - 基于HTTP的IChannel实现 (P0功能) +""" + +import threading +import time +import logging +import requests +from typing import Dict, Optional +from http.server import HTTPServer, BaseHTTPRequestHandler +import urllib.parse +import json + +import spu.libspu.link as link +import spu.libspu as spu + +# 设置日志 +logging.basicConfig( + level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s' +) +logger = logging.getLogger(__name__) + + +class HttpChannel(link.IChannel): + """基于HTTP的IChannel实现""" + + def __init__( + self, + local_rank: int, + remote_rank: int, + local_port: int, + remote_port: int, + ): + super().__init__() + self.local_rank = local_rank + self.remote_rank = remote_rank + self.local_port = local_port + self.remote_port = remote_port + self.recv_timeout = 10000 # 默认10秒 + self.throttle_window_size = 1024 + self.chunk_parallel_send_size = 4 + + # HTTP客户端 + self.session = requests.Session() + self.session.timeout = (1, 5) # (connect, read) timeout + + def _get_url(self, key: str, operation: str) -> str: + """构建HTTP URL""" + if operation == "send": + return f"http://localhost:{self.remote_port}/send/{self.local_rank}/{self.remote_rank}/{key}" + else: # recv + return f"http://localhost:{self.local_port}/recv/{self.remote_rank}/{self.local_rank}/{key}" + + def SendAsync(self, key: str, buf: bytes) -> None: + """异步发送数据""" + url = self._get_url(key, "send") + try: + response = self.session.post(url, data=buf) + response.raise_for_status() + except requests.RequestException as e: + raise RuntimeError( + f"网络错误: 无法发送数据到远程节点 {self.remote_rank}: {e}" + ) + + def SendAsyncThrottled(self, key: str, buf: bytes) -> None: + """带节流的异步发送""" + self.SendAsync(key, buf) + + def Send(self, key: str, value: bytes) -> None: + """同步发送数据""" + logger.info( + f"[HttpChannel({self.local_rank=}, {self.remote_rank=})] Start to send key: {key}, size: {len(value)} bytes" + ) + self.SendAsync(key, value) + logger.info( + f"[HttpChannel({self.local_rank=}, {self.remote_rank=})] Finished to send key: {key}" + ) + + def Recv(self, key: str) -> bytes: + """接收数据""" + url = self._get_url(key, "recv") + start_time = time.time() + logger.info( + f"[HttpChannel({self.local_rank=}, {self.remote_rank=})] Start to recv key: {key}" + ) + + while True: + try: + response = self.session.get(url) + if response.status_code == 200: + logger.info( + f"[HttpChannel({self.local_rank=}, {self.remote_rank=})] Finished to recv key: {key}" + ) + return response.content + elif response.status_code == 404: + pass # 消息未到达,继续等待 + except requests.RequestException as e: + raise RuntimeError( + f"网络错误: 无法从远程节点 {self.remote_rank} 接收数据: {e}" + ) + + # 检查超时 + if (time.time() - start_time) * 1000 > self.recv_timeout: + raise RuntimeError( + f"超时错误: 等待从远程节点 {self.remote_rank} 接收数据超时" + ) + + time.sleep(0.5) + + def SetRecvTimeout(self, timeout_ms: int) -> None: + """设置接收超时""" + self.recv_timeout = timeout_ms + + def GetRecvTimeout(self) -> int: + """获取接收超时""" + return self.recv_timeout + + def WaitLinkTaskFinish(self) -> None: + """等待任务完成""" + pass + + def Abort(self) -> None: + """中止操作""" + pass + + def SetThrottleWindowSize(self, size: int) -> None: + """设置节流窗口大小""" + self.throttle_window_size = size + + def TestSend(self, timeout: int) -> None: + """测试发送功能""" + self.Send("test", b"") + + def TestRecv(self) -> None: + """测试接收功能""" + self.Recv("test") + + def SetChunkParallelSendSize(self, size: int) -> None: + """设置并行发送大小""" + self.chunk_parallel_send_size = size + + +class HttpChannelHandler(BaseHTTPRequestHandler): + """HTTP Channel服务端处理器""" + + def __init__(self, storage: Dict[str, bytes], *args, **kwargs): + self.storage = storage + super().__init__(*args, **kwargs) + + def do_POST(self): + """处理发送请求""" + if self.path.startswith('/send/'): + # 解析URL: /send/{from_rank}/{to_rank}/{key} + parts = self.path.split('/') + if len(parts) >= 5: + from_rank = int(parts[2]) + to_rank = int(parts[3]) + key = parts[4] + + content_length = int(self.headers.get('Content-Length', 0)) + data = self.rfile.read(content_length) + + final_key = f"{key}_{from_rank}_{to_rank}" + self.storage[final_key] = data + + logger.info(f"[STORE] Saved key: {final_key}, size: {len(data)} bytes") + + self.send_response(200) + self.end_headers() + return + + self.send_response(404) + self.end_headers() + + def do_GET(self): + """处理接收请求""" + if self.path == '/keys': + # 返回所有存储的key + keys = list(self.storage.keys()) + response = json.dumps(keys).encode() + logger.info(f"[KEYS] Returning {len(keys)} keys: {keys}") + self.send_response(200) + self.send_header('Content-Type', 'application/json') + self.send_header('Content-Length', str(len(response))) + self.end_headers() + self.wfile.write(response) + return + + if self.path.startswith('/recv/'): + # 解析URL: /recv/{from_rank}/{to_rank}/{key} + parts = self.path.split('/') + if len(parts) >= 5: + from_rank = int(parts[2]) + to_rank = int(parts[3]) + key = parts[4] + + final_key = f"{key}_{from_rank}_{to_rank}" + if final_key in self.storage: + data = self.storage[final_key] + # 成功接收后删除数据,防止存储污染 + del self.storage[final_key] + logger.info( + f"[RECV] Retrieved and deleted key: {final_key}, size: {len(data)} bytes" + ) + self.send_response(200) + self.send_header('Content-Length', str(len(data))) + self.end_headers() + self.wfile.write(data) + return + else: + logger.info(f"[RECV] Key not found: {final_key}") + + self.send_response(404) + self.end_headers() + + +class HttpChannelServer: + """HTTP Channel服务端管理器""" + + def __init__(self, port: int, storage: Dict[str, bytes]): + self.port = port + self.storage = storage + self.server = None + self.thread = None + + def start(self): + """启动HTTP服务端""" + handler = lambda *args, **kwargs: HttpChannelHandler( + self.storage, *args, **kwargs + ) + self.server = HTTPServer(('localhost', self.port), handler) + + self.thread = threading.Thread(target=self.server.serve_forever, daemon=True) + self.thread.start() + + def stop(self): + """停止HTTP服务端""" + if self.server: + self.server.shutdown() + self.server.server_close() + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.stop() + + +def create_http_channel_context(rank: int, world_size: int, base_port: int = 8000): + """创建使用HttpChannel的Context""" + import spu.libspu as spu + + # 创建设备描述 + desc = link.Desc() + for i in range(world_size): + desc.add_party(f"party_{i}", f"localhost:{base_port + i}") + + # 创建HttpChannel列表 + channels = [] + for i in range(world_size): + if i == rank: + channels.append(None) # 自己不需要channel + else: + channel = HttpChannel( + local_rank=rank, + remote_rank=i, + local_port=base_port + rank, + remote_port=base_port + i, + ) + channels.append(channel) + + # 创建Context + return link.create_with_channels(desc, rank, channels) + + +if __name__ == "__main__": + import multiprocessing + import time + + def worker_process(rank: int, world_size: int, base_port: int): + """工作进程""" + # storage = {} + + # # 启动HTTP服务端 + # server = HttpChannelServer(base_port + rank, storage) + # server.start() + + # # 给其他进程启动时间 + # time.sleep(0.1) + + try: + # 创建Context + logger.info(f"Rank {rank}: 创建HttpChannel Context") + ctx = create_http_channel_context(rank, world_size, base_port) + logger.info(f"Rank {rank}: Context创建成功") + + # 验证Context功能 + logger.info(f"Rank {rank}: 验证Context功能") + logger.info( + f"Rank {rank}: Context rank={ctx.rank}, world_size={ctx.world_size}" + ) + + # 简单测试 - 测试正常交互 + logger.info(f"Rank {rank}: HttpChannel测试开始") + if rank == 0: + # Rank 0: 创建到Rank 1的通道并发送数据 + test_channel = HttpChannel( + local_rank=0, + remote_rank=1, + local_port=base_port, + remote_port=base_port + 1, + ) + data = b"hello from rank 0" + logger.info(f"Rank {rank}: 准备发送数据: '{data.decode()}'") + try: + test_channel.Send('hello_key', data) + logger.info(f"Rank {rank}: 已发送 '{data.decode()}'") + except RuntimeError as e: + logger.error(f"Rank {rank}: 发送失败 - {e}") + except Exception as e: + logger.error(f"Rank {rank}: 发送过程发生其他错误 - {e}") + else: + # Rank 1: 创建到Rank 0的通道并接收数据 + test_channel = HttpChannel( + local_rank=1, + remote_rank=0, + local_port=base_port + 1, + remote_port=base_port, + ) + test_channel.SetRecvTimeout(3000) # 3秒超时 + + logger.info(f"Rank {rank}: 等待接收数据") + try: + received = test_channel.Recv('hello_key') + logger.info(f"Rank {rank}: 已接收 '{received.decode()}'") + except RuntimeError as e: + logger.error(f"Rank {rank}: 接收失败 - {e}") + except Exception as e: + logger.error(f"Rank {rank}: 接收过程发生其他错误 - {e}") + + finally: + # server.stop() + pass + + # 测试运行 + world_size = 2 + # base_port = 9000 + base_port = 61930 + + processes = [] + for rank in range(world_size): + p = multiprocessing.Process( + target=worker_process, args=(rank, world_size, base_port) + ) + processes.append(p) + p.start() + + for p in processes: + p.join() + + print("HttpChannel测试完成") diff --git a/examples/python/advanced/test_channel_integration.py b/examples/python/advanced/test_channel_integration.py new file mode 100644 index 00000000..eefc2b7d --- /dev/null +++ b/examples/python/advanced/test_channel_integration.py @@ -0,0 +1,166 @@ +#!/usr/bin/env python3 +""" +Test script: Verify Python class implementing IChannel and passing to C++ +""" + +import spu.libspu.link as link +import spu.libspu as spu +from typing import Dict, Optional + +print(dir(link)) +print(spu.__file__) + + +class SimpleChannel(link.IChannel): + """Simple point-to-point channel implementation""" + + def __init__( + self, name: str, storage: Dict[str, bytes], local_rank: int, remote_rank: int + ): + super().__init__() + self.name = name + self.local_rank = local_rank + self.remote_rank = remote_rank + # Each channel has its own storage, peer_storage allows sharing between two channels + self.storage = storage + self.recv_timeout = 1000 + self.throttle_window_size = 1024 + self.chunk_parallel_send_size = 4 + + def SendAsync(self, key: str, buf: bytes) -> None: + """Asynchronously send data""" + final_key = f"{key}_{self.local_rank}_{self.remote_rank}" + print(f"[{self.name}] SendAsync: key={final_key}, size={len(buf)}") + self.storage[final_key] = buf + + def SendAsyncThrottled(self, key: str, buf: bytes) -> None: + """Asynchronously send data with throttling""" + final_key = f"{key}_{self.local_rank}_{self.remote_rank}" + print(f"[{self.name}] SendAsyncThrottled: key={final_key}, size={len(buf)}") + self.storage[final_key] = buf + + def Send(self, key: str, value: bytes) -> None: + """Synchronously send data""" + final_key = f"{key}_{self.local_rank}_{self.remote_rank}" + print(f"[{self.name}] Send: key={final_key}, size={len(value)}") + self.storage[final_key] = value + + def Recv(self, key: str) -> bytes: + """Receive data""" + final_key = f"{key}_{self.remote_rank}_{self.local_rank}" + print(f"[{self.name}] Recv: key={final_key}") + return self.storage.get(final_key, b"") + + def SetRecvTimeout(self, timeout_ms: int) -> None: + """Set receive timeout""" + print(f"[{self.name}] SetRecvTimeout: {timeout_ms}ms") + self.recv_timeout = timeout_ms + + def GetRecvTimeout(self) -> int: + """Get receive timeout""" + return self.recv_timeout + + def WaitLinkTaskFinish(self) -> None: + """Wait for link tasks to finish""" + print(f"[{self.name}] WaitLinkTaskFinish") + + def Abort(self) -> None: + """Abort operation""" + print(f"[{self.name}] Abort") + + def SetThrottleWindowSize(self, size: int) -> None: + """Set throttle window size""" + print(f"[{self.name}] SetThrottleWindowSize: {size}") + self.throttle_window_size = size + + def TestSend(self, timeout: int) -> None: + """Test send functionality""" + self.Send("test", b"") + + def TestRecv(self) -> None: + """Test receive functionality""" + self.Recv("test") + + def SetChunkParallelSendSize(self, size: int) -> None: + """Set chunk parallel send size""" + print(f"[{self.name}] SetChunkParallelSendSize: {size}") + self.chunk_parallel_send_size = size + + +def test_basic_functionality(): + """Test basic functionality""" + print("=== Test Basic Functionality ===") + + # Create two channels with shared storage + storage = {} + alice = SimpleChannel("Alice", storage, local_rank=0, remote_rank=1) + bob = SimpleChannel("Bob", storage, local_rank=1, remote_rank=0) + + # Test basic communication + alice.Send("test_message", b"hello bob") + received = bob.Recv("test_message") + print(f"Bob received: {received}") + + # Reverse communication + bob.Send("response", b"hello alice") + received_back = alice.Recv("response") + print(f"Alice received: {received_back}") + + +def test_with_create_with_channels(): + """Test create_with_channels interface""" + print("\n=== Test create_with_channels Interface ===") + + try: + + # For simplicity, each rank uses independent channels + # In real scenarios, these channels would connect via network + storage = {} + channelA2B = SimpleChannel("ChannelA2B", storage, local_rank=0, remote_rank=1) + channelB2A = SimpleChannel("ChannelB2A", storage, local_rank=1, remote_rank=0) + + # Create device description + desc = link.Desc() + # Add required parties for the test + desc.add_party("party_0", "127.0.0.1:9000") + desc.add_party("party_1", "127.0.0.1:9001") + + # Create device contexts with custom channels + ctxA = link.create_with_channels(desc, 0, [None, channelA2B]) + ctxB = link.create_with_channels(desc, 1, [channelB2A, None]) + + print("✅ create_with_channels interface call successful") + + # Test basic context functionality + print(f"Alice context rank: {ctxA.rank}") + print(f"Bob context rank: {ctxB.rank}") + print(f"Alice world size: {ctxA.world_size}") + print(f"Bob world size: {ctxB.world_size}") + + # Test context communication interfaces + print("Testing context Send and Recv...") + test_msg = "hello from Alice via context" + ctxA.send(1, test_msg) + received = ctxB.recv(0) + print(f"Bob received via context: {received}") + + test_msg_back = "hello from Bob via context" + ctxB.send(0, test_msg_back) + received_back = ctxA.recv(1) + print(f"Alice received via context: {received_back}") + + # Test async send + ctxA.send_async(1, "async from Alice") + received_async = ctxB.recv(0) + print(f"Bob received async: {received_async}") + + except Exception as e: + print(f"❌ create_with_channels interface call failed: {e}") + + +if __name__ == "__main__": + print("Starting IChannel integration verification...") + + # Run tests + test_basic_functionality() + test_with_create_with_channels() diff --git a/examples/python/ir_dump/ir_dump.py b/examples/python/ir_dump/ir_dump.py index 322a34e6..309f2bff 100644 --- a/examples/python/ir_dump/ir_dump.py +++ b/examples/python/ir_dump/ir_dump.py @@ -43,7 +43,7 @@ copts = libspu.CompilerOptions() copts.enable_pretty_print = True copts.pretty_print_dump_dir = dump_path -copts.xla_pp_kind = 2 +copts.xla_pp_kind = libspu.XLAPrettyPrintKind.HTML def func(x, y): diff --git a/examples/python/utils/distributed_impl.py b/examples/python/utils/distributed_impl.py index bee98908..ead64c1a 100644 --- a/examples/python/utils/distributed_impl.py +++ b/examples/python/utils/distributed_impl.py @@ -547,7 +547,28 @@ def builtin_spu_init( desc.http_max_payload_size = 32 * 1024 * 1024 # Default set link payload to 32M for rank, addr in enumerate(addrs): desc.add_party(f"r{rank}", addr) - link = libspu.link.create_brpc(desc, my_rank) + + # create HttpChannels + from examples.python.advanced.http_channel import HttpChannel + + my_port = int(addrs[my_rank].split(":")[-1]) + channels = [] + for rank, addr in enumerate(addrs): + if rank == my_rank: + channels.append(None) + continue + port = int(addr.split(":")[-1]) + channel = HttpChannel( + local_rank=my_rank, + remote_rank=rank, + local_port=my_port, + remote_port=port, + ) + channels.append(channel) + # create link with channels + link = libspu.link.create_with_channels(desc, my_rank, channels) + + # link = libspu.link.create_brpc(desc, my_rank) spu_config = libspu.RuntimeConfig() spu_config.ParseFromString(spu_config_str) if my_rank != 0: diff --git a/examples/python/utils/nodectl.py b/examples/python/utils/nodectl.py index d9a8123b..17326041 100644 --- a/examples/python/utils/nodectl.py +++ b/examples/python/utils/nodectl.py @@ -14,9 +14,18 @@ import argparse import json +import logging + +# Configure logging +logging.basicConfig( + level=logging.INFO, + format='[%(asctime)s] [%(processName)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', +) import examples.python.utils.distributed as ppd from spu.utils.polyfill import Process +from examples.python.advanced.http_channel import HttpChannelServer parser = argparse.ArgumentParser(description='SPU node service.') parser.add_argument( @@ -41,13 +50,40 @@ if args.command == 'start': ppd.RPC.serve(args.node_id, nodes_def) elif args.command == 'up': + # Start HTTP servers for HttpChannel communication + spu_devices = [d for d in devices_def.values() if d.get('kind') == 'SPU'] + http_servers = [] + logging.info("Starting HTTP servers for HttpChannel communication") + if spu_devices: + # Get all unique SPU internal addresses + all_spu_addrs = set() + for spu_def in spu_devices: + spu_config = spu_def['config'] + if 'spu_internal_addrs' in spu_config: + for addr in spu_config['spu_internal_addrs']: + all_spu_addrs.add(addr) + + # Start HTTP servers for each SPU internal address + for addr in sorted(all_spu_addrs): + port = int(addr.split(":")[-1]) + storage = {} + server = HttpChannelServer(port, storage) + server.start() + http_servers.append(server) + logging.info(f"HTTP server started on port {port}") + workers = [] for node_id in nodes_def.keys(): worker = Process(target=ppd.RPC.serve, args=(node_id, nodes_def)) worker.start() workers.append(worker) - for worker in workers: - worker.join() + try: + for worker in workers: + worker.join() + finally: + # Clean up HTTP servers + for server in http_servers: + server.stop() else: parser.print_help() diff --git a/simple_context_test.py b/simple_context_test.py new file mode 100644 index 00000000..ace598b1 --- /dev/null +++ b/simple_context_test.py @@ -0,0 +1,83 @@ +#!/usr/bin/env python3 +""" +简单测试:验证create_with_channels Context创建和基本功能 +""" + +import threading +import time +import spu.libspu.link as link +from examples.python.advanced.http_channel import HttpChannel + + +def test_create_context_simple(): + """简单测试Context创建""" + + print("=" * 50) + print("简单测试 create_with_channels Context") + print("=" * 50) + + # 测试1: 并行创建所有rank的Context + print("\n1. 并行创建所有rank的Context...") + + world_size = 3 + base_port = 61930 + contexts = [None] * world_size + + def create_context(rank): + """为指定rank创建context""" + print(f"[Rank {rank}] 开始创建Context...") + desc = link.Desc() + desc.recv_timeout_ms = 10000 # 增加超时时间 + + for i in range(world_size): + desc.add_party(f"party_{i}", f"localhost:{base_port + i}") + + channels = [] + for i in range(world_size): + if i == rank: + channels.append(None) + else: + channels.append(HttpChannel(rank, i, base_port + rank, base_port + i)) + + ctx = link.create_with_channels(desc, rank, channels) + print(f"[Rank {rank}] ✅ Context创建成功") + return ctx + + # 使用线程并行创建 + threads = [] + for rank in range(world_size): + thread = threading.Thread( + target=lambda r=rank: contexts.__setitem__(r, create_context(r)) + ) + threads.append(thread) + thread.start() + + # 等待所有线程完成 + for thread in threads: + thread.join() + + print("✅ 所有Context创建完成") + + # 验证所有Context + print("\n2. 验证所有Context...") + for i, ctx in enumerate(contexts): + print(f" Rank {i}: rank={ctx.rank}, world_size={ctx.world_size}") + + # 测试 Send/Recv + print("\n3. 测试 Send/Recv...") + + # 测试通信 + print("\n3. 测试通信...") + + # 测试1: 简单点对点 + print("Rank 0 -> Rank 1...") + data = "hello from rank 0" + contexts[0].send(1, data) + received = contexts[1].recv(0) + print(f"✅ 收到: {received}") + + print("\n✅ 所有测试通过!") + + +if __name__ == "__main__": + test_create_context_simple() diff --git a/spu/BUILD.bazel b/spu/BUILD.bazel index df11c0b4..a44e9d38 100644 --- a/spu/BUILD.bazel +++ b/spu/BUILD.bazel @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -load("@pybind11_bazel//:build_defs.bzl", "pybind_extension") +load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library") package(default_visibility = ["//visibility:public"]) @@ -37,6 +37,8 @@ pybind_extension( }), deps = [ ":exported_symbols.lds", + ":pychannel", + ":pybind_caster", ":version_script.lds", "@spulib//libspu:version", "@spulib//libspu/compiler:compile", @@ -68,3 +70,21 @@ pybind_extension( "@yacl//yacl/link", ], ) + +pybind_library( + name = "pychannel", + hdrs = ["pychannel.h"], + visibility = ["//visibility:private"], + deps = [ + "@yacl//yacl/link/transport:channel", + ], +) + +pybind_library( + name = "pybind_caster", + hdrs = ["pybind_caster.h"], + visibility = ["//visibility:private"], + deps = [ + "@yacl//yacl/base:buffer", + ], +) diff --git a/spu/libspu.cc b/spu/libspu.cc index 9144bc31..93229aec 100644 --- a/spu/libspu.cc +++ b/spu/libspu.cc @@ -13,13 +13,20 @@ // limitations under the License. #include +#include #include +#include #include +#include #include "pybind11/iostream.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h" +#include "spdlog/spdlog.h" +#include "spu/pybind_caster.h" +#include "spu/pychannel.h" +#include "yacl/base/exception.h" #include "yacl/link/algorithm/allgather.h" #include "yacl/link/algorithm/barrier.h" #include "yacl/link/algorithm/broadcast.h" @@ -40,19 +47,17 @@ #include "libspu/spu.h" #include "libspu/version.h" +// Add missing includes for brpc and fmt +#include "butil/macros.h" +#include "fmt/format.h" +#include "gflags/gflags.h" + #ifdef CHECK_AVX #include "cpu_features/cpuinfo_x86.h" #endif namespace py = pybind11; -namespace brpc { - -DECLARE_uint64(max_body_size); -DECLARE_int64(socket_max_unwritten_bytes); - -} // namespace brpc - namespace spu { namespace { @@ -67,6 +72,51 @@ namespace { } // namespace +// Convert yacl::Buffer to py::array_t using zero-copy with move semantics +py::array_t BufferToArray(yacl::Buffer buffer) { + // Create a capsule that takes ownership of the buffer data using move + // semantics This achieves zero-copy by transferring ownership instead of + // copying + auto* buf_ptr = new yacl::Buffer(std::move(buffer)); + py::capsule capsule( + buf_ptr, [](void* data) { delete static_cast(data); }); + + // Create numpy array as a view of the buffer data without copying + return py::array_t({buf_ptr->size()}, // shape + {1}, // strides + buf_ptr->data(), // data pointer + capsule); // capsule to manage lifetime +} + +// Generic buffer to numpy array conversion with type information (zero-copy) +template +py::array_t BufferToTypedArray(yacl::Buffer buffer) { + // Calculate number of elements + size_t num_elements = buffer.size() / sizeof(T); + SPU_ENFORCE(buffer.size() % sizeof(T) == 0, + "Buffer size {} is not divisible by sizeof({})", buffer.size(), + sizeof(T)); + + // Create a capsule that takes ownership of the buffer using move semantics + auto* buf_ptr = new yacl::Buffer(std::move(buffer)); + py::capsule capsule( + buf_ptr, [](void* data) { delete static_cast(data); }); + + // Create typed numpy array that shares the buffer memory + return py::array_t( + {num_elements}, // shape + {1}, // strides + reinterpret_cast(buf_ptr->data()), // data pointer + capsule); // capsule to manage lifetime +} + +// Specialization for uint8_t (most common case) +template <> +py::array_t BufferToTypedArray(yacl::Buffer buffer) { + return BufferToArray( + std::move(buffer)); // Use the existing zero-copy implementation +} + #define NO_GIL py::call_guard() void BindLink(py::module& m) { @@ -76,6 +126,7 @@ void BindLink(py::module& m) { using yacl::link::RetryOptions; using yacl::link::SSLOptions; using yacl::link::VerifyOptions; + using yacl::link::transport::IChannel; // TODO(jint) expose this tag to python? constexpr char PY_CALL_TAG[] = "PY_CALL"; @@ -84,6 +135,23 @@ void BindLink(py::module& m) { SPU Link Library )pbdoc"; + py::class_, PyChannel>(m, "IChannel") + .def(py::init<>()) + .def("send_async", + static_cast( + &IChannel::SendAsync)) + .def("send_async_throttled", + static_cast( + &IChannel::SendAsyncThrottled)) + .def("send", &IChannel::Send) + .def("recv", &IChannel::Recv) + .def("test_send", &IChannel::TestSend) + .def("test_recv", &IChannel::TestRecv) + .def("set_throttle_window_size", &IChannel::SetThrottleWindowSize) + .def("set_chunk_parallel_send_size", &IChannel::SetChunkParallelSendSize) + .def("wait_link_task_finish", &IChannel::WaitLinkTaskFinish) + .def("abort", &IChannel::Abort); + py::class_(m, "CertInfo", "The config info used for certificate") .def_readwrite("certificate_path", &CertInfo::certificate_path, "certificate file path") @@ -281,9 +349,19 @@ void BindLink(py::module& m) { [](const ContextDesc& desc, size_t self_rank, bool log_details) -> std::shared_ptr { py::gil_scoped_release release; - brpc::FLAGS_max_body_size = std::numeric_limits::max(); - brpc::FLAGS_socket_max_unwritten_bytes = - std::numeric_limits::max() / 2; + // Use gflags to set the flags properly + if (google::CommandLineFlagInfo info; + google::GetCommandLineFlagInfo("max_body_size", &info)) { + google::SetCommandLineOption( + "max_body_size", + std::to_string(std::numeric_limits::max()).c_str()); + } + if (google::CommandLineFlagInfo info; google::GetCommandLineFlagInfo( + "socket_max_unwritten_bytes", &info)) { + google::SetCommandLineOption( + "socket_max_unwritten_bytes", + std::to_string(std::numeric_limits::max() / 2).c_str()); + } auto ctx = yacl::link::FactoryBrpc().CreateContext(desc, self_rank); ctx->ConnectToMesh(log_details ? spdlog::level::info @@ -302,6 +380,17 @@ void BindLink(py::module& m) { ctx->ConnectToMesh(); return ctx; }); + m.def( + "create_with_channels", + [](const ContextDesc& desc, size_t self_rank, + std::vector> channels) { + py::gil_scoped_release release; + auto ctx = std::make_shared( + desc, self_rank, std::move(channels), nullptr, false); + ctx->ConnectToMesh(); + return ctx; + }, + py::arg("desc"), py::arg("self_rank"), py::arg("channels")); } struct PyBindShare { @@ -1092,6 +1181,10 @@ PYBIND11_MODULE(libspu, m) { }); m.def("_get_version", []() { return spu::getVersionStr(); }); + + // Expose buffer conversion functions + m.def("buffer_to_array", &BufferToArray, + "Convert yacl::Buffer to numpy array", py::arg("buffer")); } } // namespace spu diff --git a/spu/pybind_caster.h b/spu/pybind_caster.h new file mode 100644 index 00000000..ab793b5d --- /dev/null +++ b/spu/pybind_caster.h @@ -0,0 +1,88 @@ +#pragma once + +#include +#include + +#include "yacl/base/buffer.h" +#include "yacl/base/byte_container_view.h" + +namespace pybind11 { +namespace detail { + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(yacl::Buffer, const_name("bytes")); + + bool load(handle src, bool convert) { + // Try to load from bytes + if (isinstance(src)) { + std::string_view s = src.cast(); + value = yacl::Buffer(s.data(), s.size()); + return true; + } + + // If conversion is allowed, also support numpy array conversion + if (convert && isinstance(src)) { + auto arr = array_t < uint8_t, + array::c_style | array::forcecast > ::ensure(src); + if (!arr) { + return false; + } + value = yacl::Buffer(arr.data(), arr.nbytes()); + return true; + } + + return false; + } + + static handle cast(yacl::Buffer&& src, return_value_policy, handle) { + // Create Python bytes object directly from buffer data + return bytes(static_cast(src.data()), src.size()).release(); + } +}; + +template <> +struct type_caster { + public: + PYBIND11_TYPE_CASTER(yacl::ByteContainerView, const_name("bytes")); + + bool load(handle src, bool convert) { + // Try to load from bytes + if (isinstance(src)) { + std::string_view s = src.cast(); + value = yacl::ByteContainerView(s.data(), s.size()); + return true; + } + + // If conversion is allowed, also support numpy array conversion + if (convert && isinstance(src)) { + auto arr = array_t < uint8_t, + array::c_style | array::forcecast > ::ensure(src); + if (!arr) { + return false; + } + value = yacl::ByteContainerView(arr.data(), arr.nbytes()); + return true; + } + + return false; + } + + static handle cast(yacl::ByteContainerView&& src, return_value_policy, + handle) { + // Create Python bytes object directly from ByteContainerView data + return bytes(reinterpret_cast(src.data()), src.size()) + .release(); + } + + static handle cast(const yacl::ByteContainerView& src, return_value_policy, + handle) { + // Create Python bytes object directly from ByteContainerView data + return bytes(reinterpret_cast(src.data()), src.size()) + .release(); + } +}; + +} // namespace detail +} // namespace pybind11 \ No newline at end of file diff --git a/spu/pychannel.h b/spu/pychannel.h new file mode 100644 index 00000000..936a66ee --- /dev/null +++ b/spu/pychannel.h @@ -0,0 +1,81 @@ +// Copyright 2025 Ant Group Co., Ltd. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "pybind11/pybind11.h" +#include "yacl/link/transport/channel.h" + +class PyChannel : public yacl::link::transport::IChannel { + public: + using yacl::link::transport::IChannel::IChannel; + + void SendAsync(const std::string& key, yacl::Buffer buf) override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, SendAsync, + key, std::move(buf)); + } + + void SendAsyncThrottled(const std::string& key, yacl::Buffer buf) override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, + SendAsyncThrottled, key, std::move(buf)); + } + + void Send(const std::string& key, yacl::ByteContainerView value) override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, Send, key, + value); + } + + yacl::Buffer Recv(const std::string& key) override { + PYBIND11_OVERRIDE_PURE(yacl::Buffer, yacl::link::transport::IChannel, Recv, + key); + } + + void SetRecvTimeout(uint64_t timeout_ms) override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, + SetRecvTimeout, timeout_ms); + } + + uint64_t GetRecvTimeout() const override { + PYBIND11_OVERRIDE_PURE(uint64_t, yacl::link::transport::IChannel, + GetRecvTimeout); + } + + virtual void WaitLinkTaskFinish() override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, + WaitLinkTaskFinish); + } + + virtual void Abort() override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, Abort); + } + + virtual void SetThrottleWindowSize(size_t size) override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, + SetThrottleWindowSize, size); + } + + void TestSend(uint32_t timeout) override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, TestSend, + timeout); + } + + void TestRecv() override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, TestRecv); + } + + void SetChunkParallelSendSize(size_t size) override { + PYBIND11_OVERRIDE_PURE(void, yacl::link::transport::IChannel, + SetChunkParallelSendSize, size); + } +}; \ No newline at end of file diff --git a/src/MODULE.bazel.lock b/src/MODULE.bazel.lock index c7fbcaa8..43addd43 100644 --- a/src/MODULE.bazel.lock +++ b/src/MODULE.bazel.lock @@ -27,8 +27,9 @@ "https://bcr.bazel.build/modules/bazel_features/1.19.0/MODULE.bazel": "59adcdf28230d220f0067b1f435b8537dd033bfff8db21335ef9217919c7fb58", "https://bcr.bazel.build/modules/bazel_features/1.20.0/MODULE.bazel": "8b85300b9c8594752e0721a37210e34879d23adc219ed9dc8f4104a4a1750920", "https://bcr.bazel.build/modules/bazel_features/1.21.0/MODULE.bazel": "675642261665d8eea09989aa3b8afb5c37627f1be178382c320d1b46afba5e3b", - "https://bcr.bazel.build/modules/bazel_features/1.21.0/source.json": "3e8379efaaef53ce35b7b8ba419df829315a880cb0a030e5bb45c96d6d5ecb5f", "https://bcr.bazel.build/modules/bazel_features/1.3.0/MODULE.bazel": "cdcafe83ec318cda34e02948e81d790aab8df7a929cec6f6969f13a489ccecd9", + "https://bcr.bazel.build/modules/bazel_features/1.35.0/MODULE.bazel": "3d9393e5317df8afcfc509458591874ea734fa68ecbdd64fbd6c2c0cbe399526", + "https://bcr.bazel.build/modules/bazel_features/1.35.0/source.json": "c61e98cb3573ce0b8d69eb77c652ab10545375e387e45005e7f8e84792472b09", "https://bcr.bazel.build/modules/bazel_features/1.4.1/MODULE.bazel": "e45b6bb2350aff3e442ae1111c555e27eac1d915e77775f6fdc4b351b758b5d7", "https://bcr.bazel.build/modules/bazel_features/1.9.1/MODULE.bazel": "8f679097876a9b609ad1f60249c49d68bfab783dd9be012faf9d82547b14815a", "https://bcr.bazel.build/modules/bazel_skylib/1.0.3/MODULE.bazel": "bcb0fd896384802d1ad283b4e4eb4d718eebd8cb820b0a2c3a347fb971afd9d8", @@ -143,13 +144,14 @@ "https://bcr.bazel.build/modules/opentracing-cpp/1.6.0/source.json": "da1cb1add160f5e5074b7272e9db6fd8f1b3336c15032cd0a653af9d2f484aed", "https://bcr.bazel.build/modules/platforms/0.0.10/MODULE.bazel": "8cb8efaf200bdeb2150d93e162c40f388529a25852b332cec879373771e48ed5", "https://bcr.bazel.build/modules/platforms/0.0.11/MODULE.bazel": "0daefc49732e227caa8bfa834d65dc52e8cc18a2faf80df25e8caea151a9413f", - "https://bcr.bazel.build/modules/platforms/0.0.11/source.json": "f7e188b79ebedebfe75e9e1d098b8845226c7992b307e28e1496f23112e8fc29", "https://bcr.bazel.build/modules/platforms/0.0.4/MODULE.bazel": "9b328e31ee156f53f3c416a64f8491f7eb731742655a47c9eec4703a71644aee", "https://bcr.bazel.build/modules/platforms/0.0.5/MODULE.bazel": "5733b54ea419d5eaf7997054bb55f6a1d0b5ff8aedf0176fef9eea44f3acda37", "https://bcr.bazel.build/modules/platforms/0.0.6/MODULE.bazel": "ad6eeef431dc52aefd2d77ed20a4b353f8ebf0f4ecdd26a807d2da5aa8cd0615", "https://bcr.bazel.build/modules/platforms/0.0.7/MODULE.bazel": "72fd4a0ede9ee5c021f6a8dd92b503e089f46c227ba2813ff183b71616034814", "https://bcr.bazel.build/modules/platforms/0.0.8/MODULE.bazel": "9f142c03e348f6d263719f5074b21ef3adf0b139ee4c5133e2aa35664da9eb2d", "https://bcr.bazel.build/modules/platforms/0.0.9/MODULE.bazel": "4a87a60c927b56ddd67db50c89acaa62f4ce2a1d2149ccb63ffd871d5ce29ebc", + "https://bcr.bazel.build/modules/platforms/1.0.0/MODULE.bazel": "f05feb42b48f1b3c225e4ccf351f367be0371411a803198ec34a389fb22aa580", + "https://bcr.bazel.build/modules/platforms/1.0.0/source.json": "f4ff1fd412e0246fd38c82328eb209130ead81d62dcd5a9e40910f867f733d96", "https://bcr.bazel.build/modules/prometheus-cpp/1.2.4/MODULE.bazel": "0fbe5dcff66311947a3f6b86ebc6a6d9328e31a28413ca864debc4a043f371e5", "https://bcr.bazel.build/modules/prometheus-cpp/1.2.4/source.json": "aa58bb10d0bb0dcaf4ad2c509ddcec23d2e94c3935e21517a5adbc2363248a55", "https://bcr.bazel.build/modules/protobuf/27.3/MODULE.bazel": "d94898cbf9d6d25c0edca2521211413506b68a109a6b01776832ed25154d23d7", @@ -288,6 +290,7 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.20.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.21.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.3.0/MODULE.bazel": "not found", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.35.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.4.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_features/1.9.1/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/bazel_skylib/1.0.3/MODULE.bazel": "not found", @@ -372,6 +375,8 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/leveldb/1.23/source.json": "82a078a44ec4a6c299fe108e87b6d7a0ce56dd46206bfc69495001087d7e2dfb", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/lib25519/20240321/MODULE.bazel": "849ae135d5582105a552331791eb34bb881bc9cc14106c159c4bf70bed02808a", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/lib25519/20240321/source.json": "17672e2c227edd4b9a17fcc1d078fd94ae86ad1f0c78fc491880bf03565045ed", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/liboqs/0.13.0/MODULE.bazel": "a783966a6e3a205cde9e3dd7cd386cf8c60a4d57f346d1cbdeee8964da2fdd45", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/liboqs/0.13.0/source.json": "f98ef255946fe020308176f1fa43455e7290475ae741804c5ab57484621f18a0", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/libpfm/4.11.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/libsodium/1.0.18/MODULE.bazel": "0c5efe7944f6cf929c6b6414b539ceec47305cfc69856b1c7c9dc066016d50c5", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/libsodium/1.0.18/source.json": "85abb5d41e1f38b3909f24a92eca547d337bee936649d39461037b0582ec6cd1", @@ -402,6 +407,7 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.7/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.8/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/0.0.9/MODULE.bazel": "not found", + "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/platforms/1.0.0/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/prometheus-cpp/1.2.4/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/protobuf/27.3/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/pybind11_bazel/2.11.1.bzl.1/MODULE.bazel": "not found", @@ -502,8 +508,6 @@ "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/upb/0.0.0-20230907-e7430e6/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/xla/20240814.0-64bdcc5/MODULE.bazel": "5348a789c31fdf43215efa633e471bf2851ff5c5044ac2a7fd04da3a3746b72d", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/xla/20240814.0-64bdcc5/source.json": "2045ff3085f2e0ea4e391d7bde507fa0689256de98bc748370e5a5c301c1cba3", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/yacl/0.4.5b10-nightly-20250110/MODULE.bazel": "984e37b3d6d982edb1430f6dc100607569f78378bc3b662ac03802671d0b7c7a", - "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/yacl/0.4.5b10-nightly-20250110/source.json": "8ca56ae0ba48aefadf631c303ef3e0c3bb3cc2f6e0e717264a32b633c5dc4b05", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/zlib/1.2.11/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/zlib/1.2.13/MODULE.bazel": "not found", "https://raw.githubusercontent.com/secretflow/bazel-registry/main/modules/zlib/1.3.1.bcr.1/MODULE.bazel": "not found", @@ -517,7 +521,7 @@ "//bazel:defs.bzl%non_module_dependencies": { "general": { "bzlTransitiveDigest": "JT8ZLEUdrYXN19gijrHtztFq/cEAhJlRlNjhtQUlDIE=", - "usagesDigest": "+b2uLT9tSjfn3mYJ5cnCFUUpexWFRHgQaF1YrXxS54M=", + "usagesDigest": "lIH0VKi5I8BfIaKz+ZMwbrlg05gqJF0HxG/OavnlrNc=", "recordedFileInputs": {}, "recordedDirentsInputs": {}, "envVariables": {},