Skip to content

Commit 03ff7e7

Browse files
feat: support uds and use httpx instead of requests (#18)
1 parent 05827aa commit 03ff7e7

File tree

3 files changed

+40
-13
lines changed

3 files changed

+40
-13
lines changed

checkpoint_engine/ps.py

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,8 @@
1414
from functools import lru_cache
1515
from typing import TYPE_CHECKING, Annotated, Any, BinaryIO, NamedTuple
1616

17+
import httpx
1718
import numpy as np
18-
import requests
1919
import torch
2020
import torch.distributed as dist
2121
import zmq
@@ -458,9 +458,25 @@ def register_tensor(buffer: torch.Tensor, offset: int, tensor: torch.Tensor):
458458

459459

460460
def request_inference_to_update(
461-
url: str, socket_paths: dict[str, str], timeout: float = 300.0
462-
) -> None:
463-
resp = requests.post(
461+
url: str,
462+
socket_paths: dict[str, str],
463+
timeout: float = 300.0,
464+
uds: str | None = None,
465+
):
466+
"""Send an inference update request to inference server via HTTP or Unix socket.
467+
468+
Args:
469+
url (str): The HTTP URL or request path (e.g., "http://localhost:19730/inference") to send the request to.
470+
socket_paths (dict[str, str]): A dictionary containing device uuid and IPC socket paths for updating weights.
471+
timeout (float, optional): Request timeout in seconds. Defaults to 300.0.
472+
uds (str, optional): Path to a Unix domain socket. If provided, the request
is sent over that Unix socket instead of a TCP HTTP connection. Defaults to None.
474+
475+
Raises:
476+
httpx.HTTPStatusError: If the response contains an HTTP error status.
477+
httpx.RequestError: If there was an issue while making the request.
478+
"""
479+
resp = httpx.Client(transport=httpx.HTTPTransport(uds=uds)).post(
464480
url,
465481
json={
466482
"method": "update_weights_from_ipc",

examples/update.py

Lines changed: 19 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from contextlib import contextmanager
99
from typing import Literal
1010

11-
import requests
11+
import httpx
1212
import torch
1313
import torch.distributed as dist
1414
from loguru import logger
@@ -25,16 +25,19 @@ def timer(msg: str):
2525
logger.info(f"{msg} duration: {end - start:.2f} seconds")
2626

2727

28-
def check_vllm_ready(endpoint: str, inference_parallel_size: int):
28+
def check_vllm_ready(endpoint: str, inference_parallel_size: int, uds: str | None = None):
2929
if rank != rank // inference_parallel_size * inference_parallel_size:
3030
return
3131
retry_num = 0
32+
transport = None
33+
if uds is not None:
34+
transport = httpx.HTTPTransport(uds=uds)
3235
while True:
3336
try:
34-
response = requests.get(f"{endpoint}/health", timeout=10)
37+
response = httpx.Client(transport=transport).get(f"{endpoint}/health", timeout=10)
3538
response.raise_for_status()
3639
break
37-
except requests.exceptions.RequestException as e:
40+
except (httpx.ConnectError, httpx.HTTPStatusError) as e:
3841
retry_num += 1
3942
logger.warning(f"fail to check vllm ready, retry {retry_num} times, error: {e}")
4043
time.sleep(5)
@@ -67,7 +70,9 @@ def split_tensors(checkpoint_path: str, rank: int, world_size: int) -> dict[str,
6770

6871

6972
def req_inference(
70-
endpoint: str, inference_parallel_size: int
73+
endpoint: str,
74+
inference_parallel_size: int,
75+
uds: str | None = None,
7176
) -> Callable[[list[tuple[str, str]]], None]:
7277
rank = int(os.getenv("RANK", None))
7378
src = rank // inference_parallel_size * inference_parallel_size
@@ -77,6 +82,7 @@ def req_func(socket_paths: list[tuple[str, str]]):
7782
request_inference_to_update(
7883
f"{endpoint}/collective_rpc",
7984
dict(socket_paths[src : src + inference_parallel_size]),
85+
uds=uds,
8086
)
8187

8288
return req_func
@@ -92,10 +98,11 @@ def update_weights(
9298
endpoint: str,
9399
save_metas_file: str | None = None,
94100
update_method: Literal["broadcast", "p2p", "all"] = "broadcast",
101+
uds: str | None = None,
95102
):
96103
ps.register_checkpoint(checkpoint_name, files=checkpoint_files, named_tensors=named_tensors)
97104
ps.init_process_group()
98-
check_vllm_ready(endpoint, inference_parallel_size)
105+
check_vllm_ready(endpoint, inference_parallel_size, uds)
99106
dist.barrier()
100107
with timer("Gather metas"):
101108
ps.gather_metas(checkpoint_name)
@@ -122,12 +129,13 @@ def join(
122129
req_func: Callable[[list[tuple[str, str]]], None],
123130
inference_parallel_size: int,
124131
endpoint: str,
132+
uds: str | None = None,
125133
):
126134
assert load_metas_file, "load_metas_file is required"
127135
with open(load_metas_file, "rb") as f:
128136
metas = pickle.load(f)
129137
ps.init_process_group()
130-
check_vllm_ready(endpoint, inference_parallel_size)
138+
check_vllm_ready(endpoint, inference_parallel_size, uds)
131139
dist.barrier()
132140
with timer("Gather metas before join"):
133141
ps.gather_metas(checkpoint_name)
@@ -148,10 +156,11 @@ def join(
148156
parser.add_argument("--inference-parallel-size", type=int, default=8)
149157
parser.add_argument("--checkpoint-name", type=str, default="my-checkpoint-iter-0")
150158
parser.add_argument("--update-method", type=str, default="broadcast")
159+
parser.add_argument("--uds", type=str, default=None)
151160
args = parser.parse_args()
152161
rank = int(os.getenv("RANK"))
153162
world_size = int(os.getenv("WORLD_SIZE"))
154-
req_func = req_inference(args.endpoint, args.inference_parallel_size)
163+
req_func = req_inference(args.endpoint, args.inference_parallel_size, args.uds)
155164
ps = ParameterServer(auto_pg=True)
156165
if args.load_metas_file:
157166
join(
@@ -161,6 +170,7 @@ def join(
161170
req_func,
162171
args.inference_parallel_size,
163172
args.endpoint,
173+
args.uds,
164174
)
165175
else:
166176
if os.path.exists(os.path.join(args.checkpoint_path, "model.safetensors.index.json")):
@@ -179,5 +189,6 @@ def join(
179189
args.endpoint,
180190
args.save_metas_file,
181191
args.update_method,
192+
args.uds,
182193
)
183194
time.sleep(args.sleep_time)

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ dependencies = [
1414
"uvicorn",
1515
"loguru",
1616
"numpy",
17-
"requests",
17+
"httpx",
1818
]
1919

2020
[project.optional-dependencies]

0 commit comments

Comments
 (0)