Skip to content
Open
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion MODULE.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,14 @@ local_path_override(
path = "src",
)

bazel_dep(name = "psi", version = "0.6.0.dev250507")
bazel_dep(name = "psi", version = "0.6.0.dev250922")
bazel_dep(name = "yacl", version = "0.4.5b10-nightly-20250110")
# Take the yacl module from the upstream git repository at a fixed commit
# instead of the version named in the bazel_dep line for yacl.
git_override(
module_name = "yacl",
remote = "https://github.com/secretflow/yacl.git",
commit = "a8af9f85816139f62712e29f0e5421da6f7f1e32",
)

bazel_dep(name = "grpc", version = "1.66.0.bcr.4")

# pin [email protected]
Expand Down
56 changes: 21 additions & 35 deletions MODULE.bazel.lock

Large diffs are not rendered by default.

166 changes: 166 additions & 0 deletions examples/python/advanced/test_channel_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,166 @@
#!/usr/bin/env python3
"""
Test script: Verify Python class implementing IChannel and passing to C++
"""

import spu.libspu.link as link
import spu.libspu as spu
from typing import Dict, Optional
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The Optional type is imported but not used in this file. It's good practice to remove unused imports to keep the code clean.

Suggested change
from typing import Dict, Optional
from typing import Dict


# Debug output emitted at import time: list the names exposed by the link
# module and show the file path the spu.libspu module was loaded from.
print(dir(link))
print(spu.__file__)
Comment on lines +10 to +11
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

These print statements appear to be for debugging purposes. They should be removed before merging to keep the test output clean.



class SimpleChannel(link.IChannel):
    """In-memory point-to-point channel implemented in Python.

    Every message is kept in ``storage`` under the composite key
    ``"{key}_{sender_rank}_{receiver_rank}"``. Two endpoints exchange data by
    being constructed with the *same* ``storage`` mapping and mirrored
    ``local_rank`` / ``remote_rank`` values.
    """

    def __init__(
        self, name: str, storage: Dict[str, bytes], local_rank: int, remote_rank: int
    ):
        """Create one endpoint of the channel.

        Args:
            name: Label used as the prefix of every log line.
            storage: Mapping that holds the messages; pass the same object to
                the peer endpoint so both sides see the same data.
            local_rank: Rank of this endpoint.
            remote_rank: Rank of the peer endpoint.
        """
        super().__init__()
        self.name = name
        self.local_rank = local_rank
        self.remote_rank = remote_rank
        # Kept by reference (not copied): a peer built on the same dict
        # observes every write made through this endpoint.
        self.storage = storage
        # Initial settings for the tunables stored on this endpoint.
        self.recv_timeout = 1000
        self.throttle_window_size = 1024
        self.chunk_parallel_send_size = 4

    def _store(self, op: str, key: str, payload: bytes) -> None:
        """Log ``op`` and write ``payload`` under the local->remote key."""
        final_key = f"{key}_{self.local_rank}_{self.remote_rank}"
        print(f"[{self.name}] {op}: key={final_key}, size={len(payload)}")
        self.storage[final_key] = payload

    def SendAsync(self, key: str, buf: bytes) -> None:
        """Store ``buf`` for the peer; completes immediately in this mock."""
        self._store("SendAsync", key, buf)

    def SendAsyncThrottled(self, key: str, buf: bytes) -> None:
        """Store ``buf`` for the peer; no throttling is applied in this mock."""
        self._store("SendAsyncThrottled", key, buf)

    def Send(self, key: str, value: bytes) -> None:
        """Store ``value`` for the peer under ``key``."""
        self._store("Send", key, value)

    def Recv(self, key: str) -> bytes:
        """Return the message the peer stored under ``key``.

        The lookup does not block and does not consult ``recv_timeout``: if
        nothing has been stored under the remote->local key yet, ``b""`` is
        returned.
        """
        # Sender and receiver ranks are swapped relative to the send side so
        # this key matches what the peer wrote.
        final_key = f"{key}_{self.remote_rank}_{self.local_rank}"
        print(f"[{self.name}] Recv: key={final_key}")
        return self.storage.get(final_key, b"")
Comment on lines +48 to +52
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

The Recv method in SimpleChannel is a mock that doesn't block and doesn't respect the recv_timeout. It immediately returns from the storage dictionary. While this works for the current synchronous tests, it's not a realistic simulation of a channel's behavior and could lead to race conditions or flaky tests if used in asynchronous scenarios. Consider implementing a blocking wait with a timeout to make the mock more robust.


def SetRecvTimeout(self, timeout_ms: int) -> None:
    """Log the requested receive timeout and remember it on the channel."""
    log_line = f"[{self.name}] SetRecvTimeout: {timeout_ms}ms"
    print(log_line)
    self.recv_timeout = timeout_ms

def GetRecvTimeout(self) -> int:
    """Return the receive timeout currently stored on this channel."""
    current_timeout = self.recv_timeout
    return current_timeout

def WaitLinkTaskFinish(self) -> None:
    """Log the call; this mock has nothing to wait for."""
    prefix = f"[{self.name}]"
    print(f"{prefix} WaitLinkTaskFinish")

def Abort(self) -> None:
    """Log the abort request; no other action is taken."""
    prefix = f"[{self.name}]"
    print(f"{prefix} Abort")

def SetThrottleWindowSize(self, size: int) -> None:
    """Log the requested throttle window size and remember it."""
    log_line = f"[{self.name}] SetThrottleWindowSize: {size}"
    print(log_line)
    self.throttle_window_size = size

def TestSend(self, timeout: int) -> None:
    """Send an empty probe message under the key "test".

    ``timeout`` is accepted but not used by this implementation.
    """
    probe_key, probe_payload = "test", b""
    self.Send(probe_key, probe_payload)

def TestRecv(self) -> None:
    """Read the probe message stored under the key "test" and discard it."""
    probe_key = "test"
    self.Recv(probe_key)

def SetChunkParallelSendSize(self, size: int) -> None:
    """Log the requested chunk parallel send size and remember it."""
    log_line = f"[{self.name}] SetChunkParallelSendSize: {size}"
    print(log_line)
    self.chunk_parallel_send_size = size


def test_basic_functionality():
    """Exchange one message in each direction between two SimpleChannel peers.

    Raises:
        AssertionError: If either side does not read back exactly the bytes
            its peer sent.
    """
    print("=== Test Basic Functionality ===")

    # Both endpoints are built on the same dict, with mirrored ranks, so that
    # what one side sends is what the other side looks up.
    storage = {}
    alice = SimpleChannel("Alice", storage, local_rank=0, remote_rank=1)
    bob = SimpleChannel("Bob", storage, local_rank=1, remote_rank=0)

    # Alice -> Bob
    alice.Send("test_message", b"hello bob")
    received = bob.Recv("test_message")
    print(f"Bob received: {received}")
    # Verify the payload instead of only printing it, so a broken channel
    # makes the test fail.
    assert received == b"hello bob", f"unexpected payload at Bob: {received!r}"

    # Bob -> Alice
    bob.Send("response", b"hello alice")
    received_back = alice.Recv("response")
    print(f"Alice received: {received_back}")
    assert (
        received_back == b"hello alice"
    ), f"unexpected payload at Alice: {received_back!r}"


def test_with_create_with_channels():
    """Build two link contexts on top of Python channels and exchange messages.

    Raises:
        Exception: Any error raised while creating the contexts or while
            sending/receiving is logged and then re-raised, so a failure is
            visible to the caller and in the process exit status.
    """
    print("\n=== Test create_with_channels Interface ===")

    try:
        # Both channel objects share one in-memory dict, which stands in for
        # the network link a real deployment would use.
        storage = {}
        channelA2B = SimpleChannel("ChannelA2B", storage, local_rank=0, remote_rank=1)
        channelB2A = SimpleChannel("ChannelB2A", storage, local_rank=1, remote_rank=0)

        # Describe the two parties taking part in the test.
        desc = link.Desc()
        desc.add_party("party_0", "127.0.0.1:9000")
        desc.add_party("party_1", "127.0.0.1:9001")

        # Create one context per rank with the custom channels. The channel
        # list appears to be indexed by peer rank, with None in the slot of
        # the context's own rank.
        ctxA = link.create_with_channels(desc, 0, [None, channelA2B])
        ctxB = link.create_with_channels(desc, 1, [channelB2A, None])

        print("✅ create_with_channels interface call successful")

        # Basic context properties
        print(f"Alice context rank: {ctxA.rank}")
        print(f"Bob context rank: {ctxB.rank}")
        print(f"Alice world size: {ctxA.world_size}")
        print(f"Bob world size: {ctxB.world_size}")

        # Round trip through the context-level send/recv API
        print("Testing context Send and Recv...")
        test_msg = "hello from Alice via context"
        ctxA.send(1, test_msg)
        received = ctxB.recv(0)
        print(f"Bob received via context: {received}")

        test_msg_back = "hello from Bob via context"
        ctxB.send(0, test_msg_back)
        received_back = ctxA.recv(1)
        print(f"Alice received via context: {received_back}")

        # Asynchronous send followed by a regular receive
        ctxA.send_async(1, "async from Alice")
        received_async = ctxB.recv(0)
        print(f"Bob received async: {received_async}")

    except Exception as e:
        print(f"❌ create_with_channels interface call failed: {e}")
        # Re-raise: swallowing the error here would let a broken integration
        # pass silently.
        raise


if __name__ == "__main__":
    # Script entry point: announce the run, then execute each check in order.
    print("Starting IChannel integration verification...")

    for run_check in (test_basic_functionality, test_with_create_with_channels):
        run_check()
22 changes: 21 additions & 1 deletion spu/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

load("@pybind11_bazel//:build_defs.bzl", "pybind_extension")
load("@pybind11_bazel//:build_defs.bzl", "pybind_extension", "pybind_library")

package(default_visibility = ["//visibility:public"])

Expand All @@ -37,6 +37,8 @@ pybind_extension(
}),
deps = [
":exported_symbols.lds",
":pychannel",
":pybind_caster",
":version_script.lds",
"@spulib//libspu:version",
"@spulib//libspu/compiler:compile",
Expand Down Expand Up @@ -68,3 +70,21 @@ pybind_extension(
"@yacl//yacl/link",
],
)

# Header-only pybind library exposing pychannel.h; private to this package
# and built against yacl's link transport channel target.
pybind_library(
name = "pychannel",
hdrs = ["pychannel.h"],
visibility = ["//visibility:private"],
deps = [
"@yacl//yacl/link/transport:channel",
],
)

# Header-only pybind library exposing pybind_caster.h; private to this
# package and built against yacl's base buffer target.
pybind_library(
name = "pybind_caster",
hdrs = ["pybind_caster.h"],
visibility = ["//visibility:private"],
deps = [
"@yacl//yacl/base:buffer",
],
)
Loading
Loading