diff --git a/tests/test_multiple_processes.py b/tests/test_multiple_processes.py new file mode 100644 index 00000000..c8b9f786 --- /dev/null +++ b/tests/test_multiple_processes.py @@ -0,0 +1,88 @@ +import asyncio +import multiprocessing +import random +import sys + +import numpy as np +import pytest + +import ucp + + +def listener(ports): + ucp.init() + + async def _listener(ports): + async def write(ep): + close_msg = bytearray(2) + msg2send = np.arange(10) + msg2recv = np.empty_like(msg2send) + + msgs = [ep.recv(close_msg), ep.send(msg2send), ep.recv(msg2recv)] + await asyncio.gather(*msgs, loop=asyncio.get_event_loop()) + + close_msg = int.from_bytes(close_msg, sys.byteorder) + + if close_msg != 0: + await ep.close() + listeners[close_msg].close() + + listeners = {} + for port in ports: + listeners[port] = ucp.create_listener(write, port=port) + + try: + while not all(listener.closed() for listener in listeners.values()): + await asyncio.sleep(0.1) + except ucp.UCXCloseError: + pass + + asyncio.get_event_loop().run_until_complete(_listener(ports)) + + +def client(listener_ports): + ucp.init() + + async def _client(listener_ports): + async def read(port, close): + close_msg = bytearray(int(port if close else 0).to_bytes(2, sys.byteorder)) + msg2send = np.arange(10) + msg2recv = np.empty_like(msg2send) + + ep = await ucp.create_endpoint(ucp.get_address(), port) + msgs = [ep.send(close_msg), ep.send(msg2send), ep.recv(msg2recv)] + await asyncio.gather(*msgs, loop=asyncio.get_event_loop()) + + close_after = 100 + clients = [] + for i in range(close_after): + for port in listener_ports: + close = i == close_after - 1 + clients.append(read(port, close=close)) + + await asyncio.gather(*clients, loop=asyncio.get_event_loop()) + + asyncio.get_event_loop().run_until_complete(_client(listener_ports)) + + +@pytest.mark.parametrize("num_listeners", [1, 2, 4, 8]) +def test_send_recv_cu(num_listeners): + ports = set() + while len(ports) != num_listeners: + ports = ports.union( + [random.randint(13000, 23000) for n in range(num_listeners)] + ) + ports = list(ports) + + ctx = multiprocessing.get_context("spawn") + listener_process = ctx.Process(name="listener", target=listener, args=[ports]) + client_process = ctx.Process(name="client", target=client, args=[ports]) + + listener_process.start() + client_process.start() + + listener_process.join() + client_process.join() + + assert listener_process.exitcode == 0 + assert client_process.exitcode == 0 diff --git a/tests/test_multiple_nodes.py b/tests/test_single_process.py similarity index 69% rename from tests/test_multiple_nodes.py rename to tests/test_single_process.py index 5cb7e875..e512c6e9 100644 --- a/tests/test_multiple_nodes.py +++ b/tests/test_single_process.py @@ -29,33 +29,22 @@ async def client_node(port): @pytest.mark.asyncio -async def test_multiple_nodes(): - lf1 = ucp.create_listener(server_node) - lf2 = ucp.create_listener(server_node) - assert lf1.port != lf2.port - - nodes = [] - for _ in range(10): - nodes.append(client_node(lf1.port)) - nodes.append(client_node(lf2.port)) - await asyncio.gather(*nodes, loop=asyncio.get_event_loop()) - - -@pytest.mark.asyncio -async def test_one_server_many_clients(): +async def test_one_listener_many_clients(): lf = ucp.create_listener(server_node) clients = [] - for _ in range(100): + for _ in range(50): clients.append(client_node(lf.port)) await asyncio.gather(*clients, loop=asyncio.get_event_loop()) @pytest.mark.asyncio -async def test_two_servers_many_clients(): +async def test_two_listeners_many_clients(): lf1 = ucp.create_listener(server_node) lf2 = ucp.create_listener(server_node) + assert lf1.port != lf2.port + clients = [] - for _ in range(100): + for _ in range(25): clients.append(client_node(lf1.port)) clients.append(client_node(lf2.port)) await asyncio.gather(*clients, loop=asyncio.get_event_loop())