88from contextlib import contextmanager
99from typing import Literal
1010
11- import requests
11+ import httpx
1212import torch
1313import torch .distributed as dist
1414from loguru import logger
@@ -25,16 +25,19 @@ def timer(msg: str):
2525 logger .info (f"{ msg } duration: { end - start :.2f} seconds" )
2626
2727
def check_vllm_ready(endpoint: str, inference_parallel_size: int, uds: str | None = None) -> None:
    """Block until the inference server's ``/health`` endpoint answers successfully.

    Only the first rank of each inference-parallel group polls; every other
    rank returns immediately. The module-level ``rank`` is read to decide this.

    Args:
        endpoint: Base URL of the server; ``/health`` is appended to it.
        inference_parallel_size: Number of ranks per inference group.
        uds: Optional Unix domain socket path. When given, requests are sent
            through ``httpx.HTTPTransport(uds=uds)`` instead of the default
            transport.

    The function retries forever, sleeping 5 seconds between attempts and
    logging a warning on each failure.
    """
    # Only the group leader (rank aligned to the group size) performs the check.
    if rank != rank // inference_parallel_size * inference_parallel_size:
        return
    retry_num = 0
    transport = None
    if uds is not None:
        transport = httpx.HTTPTransport(uds=uds)
    # One client for the whole polling loop: building a fresh Client on every
    # retry without closing it leaks connection pools / socket handles.
    with httpx.Client(transport=transport) as client:
        while True:
            try:
                response = client.get(f"{endpoint}/health", timeout=10)
                response.raise_for_status()
                break
            except httpx.HTTPError as e:
                # httpx.HTTPError is the common base of transport failures
                # (connect/read errors, timeouts) and HTTPStatusError, so a
                # slow-starting server triggers a retry instead of a crash.
                retry_num += 1
                logger.warning(f"fail to check vllm ready, retry {retry_num} times, error: {e}")
                time.sleep(5)
@@ -67,7 +70,9 @@ def split_tensors(checkpoint_path: str, rank: int, world_size: int) -> dict[str,
6770
6871
6972def req_inference (
70- endpoint : str , inference_parallel_size : int
73+ endpoint : str ,
74+ inference_parallel_size : int ,
75+ uds : str | None = None ,
7176) -> Callable [[list [tuple [str , str ]]], None ]:
7277 rank = int (os .getenv ("RANK" , None ))
7378 src = rank // inference_parallel_size * inference_parallel_size
@@ -77,6 +82,7 @@ def req_func(socket_paths: list[tuple[str, str]]):
7782 request_inference_to_update (
7883 f"{ endpoint } /collective_rpc" ,
7984 dict (socket_paths [src : src + inference_parallel_size ]),
85+ uds = uds ,
8086 )
8187
8288 return req_func
@@ -92,10 +98,11 @@ def update_weights(
9298 endpoint : str ,
9399 save_metas_file : str | None = None ,
94100 update_method : Literal ["broadcast" , "p2p" , "all" ] = "broadcast" ,
101+ uds : str | None = None ,
95102):
96103 ps .register_checkpoint (checkpoint_name , files = checkpoint_files , named_tensors = named_tensors )
97104 ps .init_process_group ()
98- check_vllm_ready (endpoint , inference_parallel_size )
105+ check_vllm_ready (endpoint , inference_parallel_size , uds )
99106 dist .barrier ()
100107 with timer ("Gather metas" ):
101108 ps .gather_metas (checkpoint_name )
@@ -122,12 +129,13 @@ def join(
122129 req_func : Callable [[list [tuple [str , str ]]], None ],
123130 inference_parallel_size : int ,
124131 endpoint : str ,
132+ uds : str | None = None ,
125133):
126134 assert load_metas_file , "load_metas_file is required"
127135 with open (load_metas_file , "rb" ) as f :
128136 metas = pickle .load (f )
129137 ps .init_process_group ()
130- check_vllm_ready (endpoint , inference_parallel_size )
138+ check_vllm_ready (endpoint , inference_parallel_size , uds )
131139 dist .barrier ()
132140 with timer ("Gather metas before join" ):
133141 ps .gather_metas (checkpoint_name )
@@ -148,10 +156,11 @@ def join(
148156 parser .add_argument ("--inference-parallel-size" , type = int , default = 8 )
149157 parser .add_argument ("--checkpoint-name" , type = str , default = "my-checkpoint-iter-0" )
150158 parser .add_argument ("--update-method" , type = str , default = "broadcast" )
159+ parser .add_argument ("--uds" , type = str , default = None )
151160 args = parser .parse_args ()
152161 rank = int (os .getenv ("RANK" ))
153162 world_size = int (os .getenv ("WORLD_SIZE" ))
154- req_func = req_inference (args .endpoint , args .inference_parallel_size )
163+ req_func = req_inference (args .endpoint , args .inference_parallel_size , args . uds )
155164 ps = ParameterServer (auto_pg = True )
156165 if args .load_metas_file :
157166 join (
@@ -161,6 +170,7 @@ def join(
161170 req_func ,
162171 args .inference_parallel_size ,
163172 args .endpoint ,
173+ args .uds ,
164174 )
165175 else :
166176 if os .path .exists (os .path .join (args .checkpoint_path , "model.safetensors.index.json" )):
@@ -179,5 +189,6 @@ def join(
179189 args .endpoint ,
180190 args .save_metas_file ,
181191 args .update_method ,
192+ args .uds ,
182193 )
183194 time .sleep (args .sleep_time )
0 commit comments