@@ -70,6 +70,12 @@ async def teardown(self):
     async def reset_prefix_cache(self):
         return await self.inference_engine_actor.reset_prefix_cache.remote()

+    async def start_profile(self, profile_prefix: str | None = None):
+        return await self.inference_engine_actor.start_profile.remote(profile_prefix=profile_prefix)
+
+    async def stop_profile(self):
+        return await self.inference_engine_actor.stop_profile.remote()
+
     async def chat_completion(self, request_payload: Dict[str, Any]) -> Dict[str, Any]:
         return await self.inference_engine_actor.chat_completion.remote(request_payload)
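For reviewers tracing the call path: each new method is a thin async passthrough to a Ray actor. A minimal self-contained sketch of the same pattern (the `EngineActor` and `EngineWrapper` names are illustrative, not this repo's API):

```python
import asyncio
import ray

@ray.remote
class EngineActor:
    # Stand-in for the inference engine actor; the real one holds a vLLM engine.
    def start_profile(self, profile_prefix=None):
        return f"profiling started (prefix={profile_prefix})"

    def stop_profile(self):
        return "profiling stopped"

class EngineWrapper:
    def __init__(self):
        self.inference_engine_actor = EngineActor.remote()

    async def start_profile(self, profile_prefix: str | None = None):
        # Ray ObjectRefs are awaitable, so the wrapper stays fully async.
        return await self.inference_engine_actor.start_profile.remote(profile_prefix=profile_prefix)

    async def stop_profile(self):
        return await self.inference_engine_actor.stop_profile.remote()

async def main():
    w = EngineWrapper()
    print(await w.start_profile("sample_0"))
    print(await w.stop_profile())

ray.init()
asyncio.run(main())
```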
skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py (32 changes: 24 additions & 8 deletions)
@@ -345,6 +345,9 @@ class AsyncVLLMInferenceEngine(BaseVLLMInferenceEngine):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._weight_loader = VLLMWeightLoader(self.llm, is_async=True)
+        # vLLM raises if profile() is called without profiler_config; gate on it.
+        self._profile_enabled = self.llm.vllm_config.profiler_config.profiler is not None
+        self._profile_counter = 0
Contributor comment (severity: high):

Since AsyncVLLMInferenceEngine is an asynchronous engine, multiple generate calls can run concurrently. Because vLLM's profiler is global, concurrent profiling attempts will conflict. We should initialize a lock here to serialize profiling in the generate method.

Suggested change:
-        self._profile_counter = 0
+        self._profile_counter = 0
+        self._profile_lock = asyncio.Lock()
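To make the hazard concrete, a self-contained sketch (names are illustrative, not from this PR) of two concurrent coroutines racing on a single global profiler state, and the lock that serializes them:

```python
import asyncio

_profiler_active = False  # stand-in for vLLM's single global profiler state
_lock = asyncio.Lock()

async def profiled_span(name: str) -> None:
    global _profiler_active
    async with _lock:  # remove this lock and the assert below can fire
        assert not _profiler_active, f"{name}: overlapping profiling spans"
        _profiler_active = True
        await asyncio.sleep(0.1)  # stands in for the profiled generate() work
        _profiler_active = False

async def main() -> None:
    # Two concurrent "generate" calls; the lock serializes their spans.
    await asyncio.gather(profiled_span("a"), profiled_span("b"))

asyncio.run(main())
```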


     def _create_engine(self, *args, **kwargs):
         openai_kwargs = pop_openai_kwargs(kwargs)

@@ -489,17 +492,30 @@ async def generate(self, input_batch: InferenceEngineInput) -> InferenceEngineOutput:
"""Generate responses using vLLM's async engine."""
prompt_token_ids, sampling_params = self._preprocess_prompts(input_batch)

tasks = []
for prompt in prompt_token_ids:
# Schedule the collection of outputs for each prompt.
# Avoid duplicate request_ids
request_id = str(uuid4().hex)
task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params))
tasks.append(task)
outputs = await asyncio.gather(*tasks)
if self._profile_enabled:
await self.llm.start_profile(profile_prefix=f"sample_{self._profile_counter}")
self._profile_counter += 1
try:
tasks = []
for prompt in prompt_token_ids:
# Schedule the collection of outputs for each prompt.
# Avoid duplicate request_ids
request_id = str(uuid4().hex)
task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params))
tasks.append(task)
outputs = await asyncio.gather(*tasks)
finally:
if self._profile_enabled:
await self.llm.stop_profile()
Contributor comment on lines +495 to +509 (severity: high):

As mentioned in the __init__ comment, vLLM's profiler is global and does not support concurrent profiling spans. If multiple generate calls occur simultaneously while profiling is enabled, they will interfere with each other. This block should be protected by a lock to ensure that only one profiling session is active at a time.

Suggested change:
-        if self._profile_enabled:
-            await self.llm.start_profile(profile_prefix=f"sample_{self._profile_counter}")
-            self._profile_counter += 1
-        try:
-            tasks = []
-            for prompt in prompt_token_ids:
-                # Schedule the collection of outputs for each prompt.
-                # Avoid duplicate request_ids
-                request_id = str(uuid4().hex)
-                task = asyncio.create_task(self._collect_outputs(prompt, request_id, sampling_params))
-                tasks.append(task)
-            outputs = await asyncio.gather(*tasks)
-        finally:
-            if self._profile_enabled:
-                await self.llm.stop_profile()
+        async def _do_generate():
+            tasks = [
+                asyncio.create_task(self._collect_outputs(p, str(uuid4().hex), sampling_params))
+                for p in prompt_token_ids
+            ]
+            return await asyncio.gather(*tasks)
+
+        if self._profile_enabled:
+            async with self._profile_lock:
+                await self.llm.start_profile(profile_prefix=f"sample_{self._profile_counter}")
+                self._profile_counter += 1
+                try:
+                    outputs = await _do_generate()
+                finally:
+                    await self.llm.stop_profile()
+        else:
+            outputs = await _do_generate()


         return self._postprocess_outputs(outputs)

+    async def start_profile(self, profile_prefix: Optional[str] = None) -> None:
+        await self.llm.start_profile(profile_prefix=profile_prefix)
+
+    async def stop_profile(self) -> None:
+        await self.llm.stop_profile()
+
     async def wake_up(self, *args: Any, **kwargs: Any):
         await self.llm.wake_up(tags=kwargs.get("tags", None))
@@ -209,12 +209,20 @@ class RemoteInferenceClient:
     tokenizer: Optional[Any] = None
     """Optional HF tokenizer for local tokenize/detokenize (avoids HTTP round-trips)."""

+    profile_each_sample: bool = False
+    """If True, hit ``/start_profile`` and ``/stop_profile`` around each ``sample()``
+    call. Requires the server to have been launched with ``profiler_config`` set
+    (via ``engine_init_kwargs.profiler_config``). Concurrent ``sample()`` calls
+    are serialized by a lock so traces don't overlap."""
+
     # Private fields excluded from repr for cleaner output
     _session: Optional[aiohttp.ClientSession] = field(default=None, repr=False)
     _world_size: Optional[Tuple[int, int]] = field(default=None, repr=False)
     _gen_sem: Optional[asyncio.Semaphore] = field(default=None, repr=False)
     _detok_sem: Optional[asyncio.Semaphore] = field(default=None, repr=False)
     _sem_loop: Optional[asyncio.AbstractEventLoop] = field(default=None, repr=False)
+    _profile_counter: int = field(default=0, repr=False)
+    _profile_lock: Optional[asyncio.Lock] = field(default=None, repr=False)

     def __post_init__(self):
         if self.data_parallel_size <= 0:
@@ -289,10 +297,14 @@ async def _post(self, url: str, json: Dict[str, Any], headers: Optional[Dict[str
             try:
                 body = await resp.json(content_type=None)
             except Exception as e:
+                text = ""
+                try:
+                    text = await resp.text()
+                except Exception:
+                    pass
                 if 400 <= resp.status < 500:
                     # Non-JSON client error (e.g. plain text 422 from vllm-router).
                     # Raise immediately — client errors won't succeed on retry.
-                    text = await resp.text()
                     raise aiohttp.ClientResponseError(
                         resp.request_info,
                         resp.history,
@@ -301,7 +313,10 @@ async def _post(self, url: str, json: Dict[str, Any], headers: Optional[Dict[str
                         headers=resp.headers,
                     )
                 last_exc = e
-                logger.debug(f"retry {attempt + 1}/{_DATA_PLANE_RETRIES} for {url=}: {e}")
+                logger.debug(
+                    f"retry {attempt + 1}/{_DATA_PLANE_RETRIES} for {url=}: "
+                    f"status={resp.status} body={text[:200]!r}: {e}"
+                )
                 await asyncio.sleep(1)
                 continue
         raise_for_status(resp, body)
@@ -612,11 +627,28 @@ async def sample(self, request_payload: SampleRequestPayload) -> SampleResponse:

         url = f"{self.proxy_url}/inference/v1/generate"
         gen_sem, _ = self._get_semaphores()
-        if gen_sem is None:
-            response = await self._post(url, json=payload, headers=headers)
-        else:
-            async with gen_sem:
-                response = await self._post(url, json=payload, headers=headers)
+
+        async def _do_post() -> Dict[str, Any]:
+            if gen_sem is None:
+                return await self._post(url, json=payload, headers=headers)
+            async with gen_sem:
+                return await self._post(url, json=payload, headers=headers)
+
+        if self.profile_each_sample:
+            # start/stop_profile is global per-engine, so serialize concurrent
+            # samples to keep traces clean.
+            if self._profile_lock is None:
+                self._profile_lock = asyncio.Lock()
+            async with self._profile_lock:
+                prefix = f"sample_{self._profile_counter}"
+                self._profile_counter += 1
+                await self.start_profile(profile_prefix=prefix)
+                try:
+                    response = await _do_post()
+                finally:
+                    await self.stop_profile()
+        else:
+            response = await _do_post()
         # vLLM returns: list[dict[str(token_id) → {"logprob": float, ...}] | None]
         result_prompt_logprobs: Optional[List[Optional[float]]] = None

@@ -880,6 +912,20 @@ async def _call_all_servers(
         )
         return {url: resp for url, resp in results}

+    async def start_profile(self, profile_prefix: Optional[str] = None) -> Dict[str, Any]:
+        """Open a profiler span on every backend server.
+
+        Requires the server to have been launched with ``profiler_config`` set
+        (otherwise vLLM raises a 500). vLLM's ``/start_profile`` endpoint accepts
+        ``profile_prefix`` as a query param (used as the trace filename prefix).
+        """
+        params = {"profile_prefix": profile_prefix} if profile_prefix else None
+        return await self._call_all_servers("/start_profile", params=params)
+
+    async def stop_profile(self) -> Dict[str, Any]:
+        """Close the profiler span on every backend and flush traces to ``torch_profiler_dir``."""
+        return await self._call_all_servers("/stop_profile")
+
     async def pause(self, mode: Union[PauseMode, str] = PauseMode.KEEP, clear_cache: bool = False) -> Dict[str, Any]:
         """
         Pause generation on all backends.
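A sketch of how the client-side API added above can be driven manually. Construction is abbreviated (the real dataclass has more fields), the URLs are assumed local endpoints, and the elided `sample()` calls stand in for real requests:

```python
import asyncio

async def main():
    # Hypothetical setup: a running proxy plus one backend server.
    client = RemoteInferenceClient(
        proxy_url="http://localhost:8080",      # assumed proxy address
        server_urls=["http://localhost:8000"],  # assumed backend list
        profile_each_sample=False,              # manual span instead of per-sample wrapping
    )
    # A manual span across several requests: one trace covering all of them,
    # rather than one trace per sample() as profile_each_sample=True would produce.
    await client.start_profile(profile_prefix="warmup_batch")
    try:
        ...  # issue sample() calls here
    finally:
        await client.stop_profile()

asyncio.run(main())
```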
skyrl/backends/skyrl_train/inference_servers/setup.py (11 changes: 11 additions & 0 deletions)
@@ -283,6 +283,16 @@ def build_new_inference_client(
     active_lora_name = (
         _SKYRL_LORA_ADAPTER_NAME if lora_cfg and lora_cfg.rank > 0 and cfg.trainer.strategy != "megatron" else None
     )
+
+    # Auto-enable per-sample profiling when the user configured a vLLM profiler
+    # via engine_init_kwargs.profiler_config. Accept both dict (raw user input)
+    # and ProfilerConfig (post-coercion in build_vllm_cli_args).
+    profiler_cfg = ie_cfg.engine_init_kwargs.get("profiler_config") if ie_cfg.engine_init_kwargs else None
+    if isinstance(profiler_cfg, dict):
+        profile_each_sample = bool(profiler_cfg.get("profiler"))
+    else:
+        profile_each_sample = bool(profiler_cfg and getattr(profiler_cfg, "profiler", None))
+
     client = RemoteInferenceClient(
         proxy_url=server_setup.proxy_url,
         server_urls=server_setup.server_urls,
@@ -292,6 +302,7 @@ def build_new_inference_client(
         uses_lora_weight_sync=_uses_lora_weight_sync(cfg),
         data_parallel_size=ie_cfg.data_parallel_size,
         tokenizer=tokenizer,
+        profile_each_sample=profile_each_sample,
     )

     return client, server_setup
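To make the auto-enable condition concrete, a sketch of the user-facing config that flows through this branch. The only key the code above actually inspects is ``profiler``; the other field name is an assumption based on the ``torch_profiler_dir`` mentioned in the client docstring, not a confirmed vLLM ProfilerConfig schema:

```python
# Hypothetical content of ie_cfg.engine_init_kwargs. With "profiler" truthy,
# build_new_inference_client sets profile_each_sample=True; with
# profiler_config absent (the default), per-sample profiling stays off.
engine_init_kwargs = {
    "profiler_config": {
        "profiler": "torch",                      # assumed profiler backend name
        "torch_profiler_dir": "/tmp/vllm_traces", # assumed trace output directory
    }
}

profiler_cfg = engine_init_kwargs.get("profiler_config")
assert isinstance(profiler_cfg, dict) and bool(profiler_cfg.get("profiler"))  # auto-enabled
```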
skyrl/backends/skyrl_train/inference_servers/utils.py (11 changes: 9 additions & 2 deletions)
@@ -102,8 +102,15 @@ def build_vllm_cli_args(cfg: SkyRLTrainConfig) -> Namespace:
     else:
         args.enable_lora = False

-    # Add any extra engine_init_kwargs
-    engine_kwargs = get_config_as_dict(ie_cfg.engine_init_kwargs)
+    # Add any extra engine_init_kwargs. Copy so we don't mutate the source
+    # config (downstream readers in setup.py expect the original shape).
+    engine_kwargs = dict(get_config_as_dict(ie_cfg.engine_init_kwargs))
+    # vLLM's API server asserts args.profiler_config is a ProfilerConfig
+    # instance (not a dict), so coerce here when the user supplies it as a dict.
+    if isinstance(engine_kwargs.get("profiler_config"), dict):
+        from vllm.config.profiler import ProfilerConfig
+
+        engine_kwargs["profiler_config"] = ProfilerConfig(**engine_kwargs["profiler_config"])
     for key, value in engine_kwargs.items():
         setattr(args, key, value)
