-
Notifications
You must be signed in to change notification settings - Fork 136
[WIP] add pybind for IChannel #1265
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 9 commits
5d7070f
cd91dea
a2e383e
7ef41cf
151246f
9600eab
15bd6c0
40e79f6
cef5863
4dd11fc
c459ce7
42a4843
c0e0912
b2759d0
44f9491
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -36,8 +36,14 @@ local_path_override( | |
path = "src", | ||
) | ||
|
||
bazel_dep(name = "psi", version = "0.6.0.dev250507") | ||
bazel_dep(name = "psi", version = "0.6.0.dev250922") | ||
bazel_dep(name = "yacl", version = "0.4.5b10-nightly-20250110") | ||
git_override( | ||
module_name = "yacl", | ||
remote = "https://github.com/secretflow/yacl.git", | ||
commit = "a8af9f85816139f62712e29f0e5421da6f7f1e32", | ||
) | ||
|
||
bazel_dep(name = "grpc", version = "1.66.0.bcr.4") | ||
|
||
# pin [email protected] | ||
|
Large diffs are not rendered by default.
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,166 @@ | ||
#!/usr/bin/env python3 | ||
""" | ||
Test script: Verify Python class implementing IChannel and passing to C++ | ||
""" | ||
|
||
import spu.libspu.link as link | ||
import spu.libspu as spu | ||
from typing import Dict, Optional | ||
|
||
print(dir(link)) | ||
print(spu.__file__) | ||
Comment on lines
+10
to
+11
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
|
||
|
||
class SimpleChannel(link.IChannel): | ||
"""Simple point-to-point channel implementation""" | ||
|
||
def __init__( | ||
self, name: str, storage: Dict[str, bytes], local_rank: int, remote_rank: int | ||
): | ||
super().__init__() | ||
self.name = name | ||
self.local_rank = local_rank | ||
self.remote_rank = remote_rank | ||
# Each channel has its own storage, peer_storage allows sharing between two channels | ||
self.storage = storage | ||
self.recv_timeout = 1000 | ||
self.throttle_window_size = 1024 | ||
self.chunk_parallel_send_size = 4 | ||
|
||
def SendAsync(self, key: str, buf: bytes) -> None: | ||
"""Asynchronously send data""" | ||
final_key = f"{key}_{self.local_rank}_{self.remote_rank}" | ||
print(f"[{self.name}] SendAsync: key={final_key}, size={len(buf)}") | ||
self.storage[final_key] = buf | ||
|
||
def SendAsyncThrottled(self, key: str, buf: bytes) -> None: | ||
"""Asynchronously send data with throttling""" | ||
final_key = f"{key}_{self.local_rank}_{self.remote_rank}" | ||
print(f"[{self.name}] SendAsyncThrottled: key={final_key}, size={len(buf)}") | ||
self.storage[final_key] = buf | ||
|
||
def Send(self, key: str, value: bytes) -> None: | ||
"""Synchronously send data""" | ||
final_key = f"{key}_{self.local_rank}_{self.remote_rank}" | ||
print(f"[{self.name}] Send: key={final_key}, size={len(value)}") | ||
self.storage[final_key] = value | ||
|
||
def Recv(self, key: str) -> bytes: | ||
"""Receive data""" | ||
final_key = f"{key}_{self.remote_rank}_{self.local_rank}" | ||
print(f"[{self.name}] Recv: key={final_key}") | ||
return self.storage.get(final_key, b"") | ||
Comment on lines
+48
to
+52
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. The |
||
|
||
def SetRecvTimeout(self, timeout_ms: int) -> None: | ||
"""Set receive timeout""" | ||
print(f"[{self.name}] SetRecvTimeout: {timeout_ms}ms") | ||
self.recv_timeout = timeout_ms | ||
|
||
def GetRecvTimeout(self) -> int: | ||
"""Get receive timeout""" | ||
return self.recv_timeout | ||
|
||
def WaitLinkTaskFinish(self) -> None: | ||
"""Wait for link tasks to finish""" | ||
print(f"[{self.name}] WaitLinkTaskFinish") | ||
|
||
def Abort(self) -> None: | ||
"""Abort operation""" | ||
print(f"[{self.name}] Abort") | ||
|
||
def SetThrottleWindowSize(self, size: int) -> None: | ||
"""Set throttle window size""" | ||
print(f"[{self.name}] SetThrottleWindowSize: {size}") | ||
self.throttle_window_size = size | ||
|
||
def TestSend(self, timeout: int) -> None: | ||
"""Test send functionality""" | ||
self.Send("test", b"") | ||
|
||
def TestRecv(self) -> None: | ||
"""Test receive functionality""" | ||
self.Recv("test") | ||
|
||
def SetChunkParallelSendSize(self, size: int) -> None: | ||
"""Set chunk parallel send size""" | ||
print(f"[{self.name}] SetChunkParallelSendSize: {size}") | ||
self.chunk_parallel_send_size = size | ||
|
||
|
||
def test_basic_functionality(): | ||
"""Test basic functionality""" | ||
print("=== Test Basic Functionality ===") | ||
|
||
# Create two channels with shared storage | ||
storage = {} | ||
alice = SimpleChannel("Alice", storage, local_rank=0, remote_rank=1) | ||
bob = SimpleChannel("Bob", storage, local_rank=1, remote_rank=0) | ||
|
||
# Test basic communication | ||
alice.Send("test_message", b"hello bob") | ||
received = bob.Recv("test_message") | ||
print(f"Bob received: {received}") | ||
|
||
# Reverse communication | ||
bob.Send("response", b"hello alice") | ||
received_back = alice.Recv("response") | ||
print(f"Alice received: {received_back}") | ||
|
||
|
||
def test_with_create_with_channels(): | ||
"""Test create_with_channels interface""" | ||
print("\n=== Test create_with_channels Interface ===") | ||
|
||
try: | ||
|
||
# For simplicity, each rank uses independent channels | ||
# In real scenarios, these channels would connect via network | ||
storage = {} | ||
channelA2B = SimpleChannel("ChannelA2B", storage, local_rank=0, remote_rank=1) | ||
channelB2A = SimpleChannel("ChannelB2A", storage, local_rank=1, remote_rank=0) | ||
|
||
# Create device description | ||
desc = link.Desc() | ||
# Add required parties for the test | ||
desc.add_party("party_0", "127.0.0.1:9000") | ||
desc.add_party("party_1", "127.0.0.1:9001") | ||
|
||
# Create device contexts with custom channels | ||
ctxA = link.create_with_channels(desc, 0, [None, channelA2B]) | ||
ctxB = link.create_with_channels(desc, 1, [channelB2A, None]) | ||
|
||
print("✅ create_with_channels interface call successful") | ||
|
||
# Test basic context functionality | ||
print(f"Alice context rank: {ctxA.rank}") | ||
print(f"Bob context rank: {ctxB.rank}") | ||
print(f"Alice world size: {ctxA.world_size}") | ||
print(f"Bob world size: {ctxB.world_size}") | ||
|
||
# Test context communication interfaces | ||
print("Testing context Send and Recv...") | ||
test_msg = "hello from Alice via context" | ||
ctxA.send(1, test_msg) | ||
received = ctxB.recv(0) | ||
print(f"Bob received via context: {received}") | ||
|
||
test_msg_back = "hello from Bob via context" | ||
ctxB.send(0, test_msg_back) | ||
received_back = ctxA.recv(1) | ||
print(f"Alice received via context: {received_back}") | ||
|
||
# Test async send | ||
ctxA.send_async(1, "async from Alice") | ||
received_async = ctxB.recv(0) | ||
print(f"Bob received async: {received_async}") | ||
|
||
except Exception as e: | ||
print(f"❌ create_with_channels interface call failed: {e}") | ||
|
||
|
||
if __name__ == "__main__": | ||
print("Starting IChannel integration verification...") | ||
|
||
# Run tests | ||
test_basic_functionality() | ||
test_with_create_with_channels() |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
The
Optional
type is imported but not used in this file. It's good practice to remove unused imports to keep the code clean.