Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -215,28 +215,28 @@ async def cleanup_engine(engine, client=None, timeout=30.0):
except (asyncio.TimeoutError, Exception):
pass

sub = getattr(engine, 'model_parallel_num_msgs_subscriber_socket', None)
if sub is not None:
sub.setsockopt(zmq.RCVTIMEO, 1000)

# Close ZMQ communicator sockets to unblock any stuck ranks.
for attr in ('expert_parallel_zmq_communicator', 'world_zmq_communicator'):
comm = getattr(engine, attr, None)
if comm is not None:
comm.close()

task.cancel()
if client is not None:
client.stop_engines()
try:
await asyncio.wait_for(asyncio.shield(task), timeout=5.0)
await asyncio.wait_for(asyncio.shield(task), timeout=timeout)
except (asyncio.TimeoutError, asyncio.CancelledError, Exception):
pass
# Graceful stop failed — fall back to forcible cleanup.
for attr in ('expert_parallel_zmq_communicator', 'world_zmq_communicator'):
comm = getattr(engine, attr, None)
if comm is not None:
comm.close()

for socket in getattr(engine, 'zmq_sockets', []):
if not socket.closed:
socket.close(linger=0)

task.cancel()
try:
await asyncio.wait_for(asyncio.shield(task), timeout=5.0)
except (asyncio.TimeoutError, asyncio.CancelledError, Exception):
pass

if client is not None:
# Walk the coordinator back to RUNNING regardless of its current state
# so the next test starts cleanly. Each call is a no-op when the
# coordinator is already in the target state (just logs a warning).
client.resume_engines() # SUSPENDED → PAUSED (no-op otherwise)
client.unpause_engines() # PAUSED → RUNNING (no-op otherwise)
client.stop()


Expand Down Expand Up @@ -357,8 +357,6 @@ async def test_parallel_configs(self, initialize_model_parallel, coordinator):
client = None
try:
if rank == 0:
# Yield so engine loop can run before we block the event loop
# with the client's synchronous connect handshake.
await asyncio.sleep(0)
client = InferenceClient(dp_addr)
client.start()
Expand All @@ -382,7 +380,8 @@ async def test_parallel_configs(self, initialize_model_parallel, coordinator):
@pytest.mark.internal
@pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test")
@pytest.mark.asyncio
async def test_deserialize_flag(self, initialize_model_parallel, coordinator):
@pytest.mark.parametrize("deserialize", [True, False], ids=["deserialize", "raw"])
async def test_deserialize_flag(self, initialize_model_parallel, coordinator, deserialize):
"""Test that the correct response type is returned based on the deserialize flag."""
dp_addr = coordinator
port = int(dp_addr.rsplit(":", 1)[-1])
Expand All @@ -393,6 +392,7 @@ async def test_deserialize_flag(self, initialize_model_parallel, coordinator):
inference_coordinator_port=port, launch_inference_coordinator=False
)

# Ensure all engines are registered before submitting requests.
await asyncio.wait_for(
asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier), timeout=30.0
)
Expand All @@ -401,28 +401,18 @@ async def test_deserialize_flag(self, initialize_model_parallel, coordinator):
try:
if torch.distributed.get_rank() == 0:
await asyncio.sleep(0)
# Test deserialize=True
client = InferenceClient(dp_addr, deserialize=True)
client = InferenceClient(dp_addr, deserialize=deserialize)
client.start()
futures = [
client.add_request(prompt=prompt, sampling_params=params)
for prompt, params in requests
]
results = await asyncio.wait_for(asyncio.gather(*futures), timeout=10.0)
for result in results:
assert isinstance(result, DynamicInferenceRequest)
client.stop()

# Test deserialize=False (default)
client = InferenceClient(dp_addr)
client.start()
futures = [
client.add_request(prompt=prompt, sampling_params=params)
for prompt, params in requests
]
results = await asyncio.wait_for(asyncio.gather(*futures), timeout=10.0)
for result in results:
assert isinstance(result, dict)
if deserialize:
assert isinstance(result, DynamicInferenceRequest)
else:
assert isinstance(result, dict)

await asyncio.wait_for(
asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier),
Expand Down Expand Up @@ -624,15 +614,13 @@ def assert_state(eng, expected):
for f in doomed_futures:
assert f.cancelled(), "Client futures should be cancelled after client.stop()"

@pytest.mark.flaky
@pytest.mark.flaky_in_dev
@pytest.mark.internal
@pytest.mark.skipif(not HAVE_ZMQ, reason="pyzmq is required for this test")
@pytest.mark.asyncio
async def test_throughput(self, initialize_model_parallel, coordinator):
"""Throughput benchmark: measures ZMQ packet rate."""
_, dp, _, _, _ = initialize_model_parallel
num_requests = 10**4
num_requests = 10**3
num_iterations = 10

dp_addr = coordinator
Expand Down Expand Up @@ -668,4 +656,4 @@ async def test_throughput(self, initialize_model_parallel, coordinator):
)
await asyncio.get_event_loop().run_in_executor(None, torch.distributed.barrier)
finally:
await cleanup_engine(engine, client)
await cleanup_engine(engine, client, timeout=60.0)
Loading