Skip to content

Commit 333f841

Browse files
test: add HTTP and ray utility tests
1 parent a2a8f82 commit 333f841

File tree

2 files changed

+153
-0
lines changed

2 files changed

+153
-0
lines changed

tests/unit/utils/test_http.py

Lines changed: 122 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,122 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import json
8+
import threading
9+
from http.server import BaseHTTPRequestHandler, HTTPServer
10+
from unittest.mock import patch
11+
12+
import aiohttp
13+
import pytest
14+
from aiohttp import web
15+
16+
from matrix.utils.http import fetch_url, fetch_url_sync, post_url
17+
18+
19+
@pytest.mark.asyncio
20+
async def test_fetch_url(unused_tcp_port):
21+
"""Verify async HTTP helper for GET requests."""
22+
23+
async def handle_get(request):
24+
return web.Response(text="hello", status=200)
25+
26+
app = web.Application()
27+
app.router.add_get("/", handle_get)
28+
29+
runner = web.AppRunner(app)
30+
await runner.setup()
31+
port = unused_tcp_port
32+
site = web.TCPSite(runner, "127.0.0.1", port)
33+
await site.start()
34+
35+
try:
36+
url = f"http://127.0.0.1:{port}/"
37+
status, content = await fetch_url(url)
38+
assert status == 200
39+
assert content == "hello"
40+
finally:
41+
await runner.cleanup()
42+
43+
44+
@pytest.mark.asyncio
45+
async def test_fetch_url_handles_errors(unused_tcp_port):
46+
"""Ensure fetch_url gracefully handles network errors."""
47+
48+
status, content = await fetch_url(f"http://127.0.0.1:{unused_tcp_port}")
49+
assert status is None
50+
assert "Unexpected error" in content
51+
52+
53+
@pytest.mark.asyncio
54+
async def test_post_url(unused_tcp_port):
55+
"""Verify async POST helper sends data correctly."""
56+
57+
async def handle_post(request):
58+
payload = await request.json()
59+
return web.json_response(payload)
60+
61+
app = web.Application()
62+
app.router.add_post("/", handle_post)
63+
64+
runner = web.AppRunner(app)
65+
await runner.setup()
66+
port = unused_tcp_port
67+
site = web.TCPSite(runner, "127.0.0.1", port)
68+
await site.start()
69+
70+
try:
71+
async with aiohttp.ClientSession() as session:
72+
status, content = await post_url(
73+
session, f"http://127.0.0.1:{port}/", {"foo": "bar"}
74+
)
75+
assert status == 200
76+
assert json.loads(content) == {"foo": "bar"}
77+
finally:
78+
await runner.cleanup()
79+
80+
81+
@pytest.mark.asyncio
82+
async def test_post_url_handles_errors():
83+
"""Ensure post_url surfaces unexpected errors."""
84+
85+
async with aiohttp.ClientSession() as session:
86+
with patch.object(session, "post", side_effect=Exception("boom")):
87+
status, content = await post_url(session, "http://example.com")
88+
89+
assert status is None
90+
assert "boom" in content
91+
92+
93+
def test_fetch_url_sync(unused_tcp_port):
94+
"""Ensure synchronous fetch works and handles errors."""
95+
96+
class Handler(BaseHTTPRequestHandler):
97+
def do_GET(self):
98+
self.send_response(200)
99+
self.end_headers()
100+
self.wfile.write(b"ok")
101+
102+
def log_message(self, format, *args): # pragma: no cover
103+
pass
104+
105+
port = unused_tcp_port
106+
server = HTTPServer(("127.0.0.1", port), Handler)
107+
thread = threading.Thread(target=server.serve_forever, daemon=True)
108+
thread.start()
109+
110+
try:
111+
status, content = fetch_url_sync(f"http://127.0.0.1:{port}")
112+
assert status == 200
113+
assert content == "ok"
114+
finally:
115+
server.shutdown()
116+
thread.join()
117+
118+
with patch("requests.get") as mock_get:
119+
mock_get.side_effect = Exception("boom")
120+
status, content = fetch_url_sync("http://127.0.0.1:1")
121+
assert status is None
122+
assert "boom" in content

tests/unit/utils/test_ray.py

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
# Copyright (c) Meta Platforms, Inc. and affiliates.
2+
# All rights reserved.
3+
#
4+
# This source code is licensed under the license found in the
5+
# LICENSE file in the root directory of this source tree.
6+
7+
import pytest
8+
9+
from matrix.common.cluster_info import ClusterInfo
10+
from matrix.utils import ray as ray_utils
11+
12+
13+
def test_get_ray_addresses():
14+
info = ClusterInfo(hostname="host", client_server_port=10001, dashboard_port=8265)
15+
assert ray_utils.get_ray_address(info) == "ray://host:10001"
16+
assert ray_utils.get_ray_dashboard_address(info) == "http://host:8265"
17+
18+
19+
def test_status_helpers():
20+
assert ray_utils.status_is_success("RUNNING")
21+
for status in ["DEPLOY_FAILED", "DELETING"]:
22+
assert ray_utils.status_is_failure(status)
23+
for status in ["NOT_STARTED", "DEPLOYING", "UNHEALTHY"]:
24+
assert ray_utils.status_is_pending(status)
25+
26+
for fn in [
27+
ray_utils.status_is_success,
28+
ray_utils.status_is_failure,
29+
ray_utils.status_is_pending,
30+
]:
31+
assert not fn("UNKNOWN")

0 commit comments

Comments
 (0)