Skip to content

Commit f5e68d8

Browse files
erictang000 and claude committed
[multi-lora-rl] Address review feedback: path safety + legacy lora_name
Two reviewer concerns: 1. Path traversal in os.path.join(base_sync_path, model_id). model_id is server-generated (api.py validates against ID_PATTERN), so this is defense in depth, but route through os.path.basename in both Megatron and FSDP workers so a misformed id can't escape lora_sync_path. Also add _cleanup_lora_sync_subdir on per-adapter delete_model so the per-tenant subdirs don't accumulate as adapters churn. 2. Legacy update_named_weights path didn't carry the adapter name — vllm_engine generated a numeric name from time.time_ns(), making the adapter inaccessible by tenant model_id. Add lora_name to LoraLoadRequest, plumb through both BaseVLLMInferenceEngine variants (sync + async _load_lora_from_disk), and pass lora_name from both worker files in the legacy branch. Empty string preserves the pre-existing single-tenant behavior. All 6 existing multi-LoRA tests still pass. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
1 parent 468d231 commit f5e68d8

5 files changed

Lines changed: 67 additions & 14 deletions

File tree

skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py

Lines changed: 20 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -308,10 +308,16 @@ async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo")
308308
args=(pickled_init_info,),
309309
)
310310

311-
async def _load_lora_from_disk(self, lora_path: str):
312-
"""Load LoRA adapters from disk using vLLM's native add_lora method."""
311+
async def _load_lora_from_disk(self, lora_path: str, lora_name: str = ""):
312+
"""Load LoRA adapters from disk using vLLM's native add_lora method.
313+
314+
When ``lora_name`` is empty (legacy single-tenant), a numeric name is
315+
generated. Multi-tenant callers pass ``lora_name`` so subsequent
316+
``model=<lora_name>`` sampling routes to the right adapter.
317+
"""
313318
lora_id = int(time.time_ns() % 0x7FFFFFFF)
314-
lora_request = LoRARequest(lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=lora_path)
319+
name = lora_name or f"{lora_id}"
320+
lora_request = LoRARequest(lora_name=name, lora_int_id=lora_id, lora_path=lora_path)
315321
result = self.llm.llm_engine.add_lora(lora_request)
316322
return result
317323

@@ -320,7 +326,7 @@ async def update_named_weights(self, request: WeightUpdateRequest):
320326

321327
# Handle LoRA disk loading request
322328
if isinstance(request, LoraLoadRequest):
323-
return await self._load_lora_from_disk(request.lora_path)
329+
return await self._load_lora_from_disk(request.lora_path, lora_name=request.lora_name)
324330

325331
if not len(request):
326332
raise ValueError("Weight update request must not be empty")
@@ -453,10 +459,16 @@ def _create_ray_prometheus_stat_loggers(self):
453459
)
454460
return None
455461

456-
async def _load_lora_from_disk(self, lora_path: str):
457-
"""Load LoRA adapters from disk using vLLM's native add_lora method."""
462+
async def _load_lora_from_disk(self, lora_path: str, lora_name: str = ""):
463+
"""Load LoRA adapters from disk using vLLM's native add_lora method.
464+
465+
When ``lora_name`` is empty (legacy single-tenant), a numeric name is
466+
generated. Multi-tenant callers pass ``lora_name`` so subsequent
467+
``model=<lora_name>`` sampling routes to the right adapter.
468+
"""
458469
lora_id = int(time.time_ns() % 0x7FFFFFFF)
459-
lora_request = LoRARequest(lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=lora_path)
470+
name = lora_name or f"{lora_id}"
471+
lora_request = LoRARequest(lora_name=name, lora_int_id=lora_id, lora_path=lora_path)
460472
result = await self.llm.add_lora(lora_request)
461473
return result
462474

@@ -539,7 +551,7 @@ async def update_named_weights(self, request: WeightUpdateRequest):
539551

540552
# Check for LoRA disk loading request
541553
if isinstance(request, LoraLoadRequest):
542-
return await self._load_lora_from_disk(request.lora_path)
554+
return await self._load_lora_from_disk(request.lora_path, lora_name=request.lora_name)
543555

544556
if not len(request):
545557
raise ValueError("Weight update request must not be empty")

skyrl/backends/skyrl_train/weight_sync/base.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,12 +47,18 @@ class LoraLoadRequest(WeightUpdateRequest):
4747
from disk rather than transferring weights from training. Unlike other
4848
WeightUpdateRequest subclasses, this doesn't transfer weights - it tells
4949
the inference engine to load LoRA from a path.
50+
51+
``lora_name`` is the name vLLM should register the adapter under and is
52+
what callers later pass as ``model=<lora_name>`` when sampling. Empty
53+
string preserves the legacy single-tenant behavior where the engine
54+
generates a numeric name itself.
5055
"""
5156

5257
names: List[str] = field(default_factory=list)
5358
dtypes: List[str] = field(default_factory=list)
5459
shapes: List[List[int]] = field(default_factory=list)
5560
lora_path: str = ""
61+
lora_name: str = ""
5662

5763

5864
@dataclass

skyrl/backends/skyrl_train/workers/fsdp/fsdp_worker.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -273,7 +273,7 @@ async def _save_lora_adapters_and_sync(
273273
if isinstance(inference_engine_client, RemoteInferenceClient):
274274
await inference_engine_client.load_lora_adapter(lora_name, lora_sync_path)
275275
else:
276-
lora_request = LoraLoadRequest(lora_path=lora_sync_path)
276+
lora_request = LoraLoadRequest(lora_path=lora_sync_path, lora_name=lora_name)
277277
await inference_engine_client.update_named_weights(lora_request)
278278

279279
torch.distributed.barrier()
@@ -302,8 +302,12 @@ async def broadcast_to_inference_engines(
302302
# Multi-tenant: per-adapter subdir + per-adapter vLLM name. Single
303303
# tenant (model_id=None) keeps the legacy single-path behavior.
304304
base_sync_path = self.cfg.policy.model.lora.lora_sync_path
305-
lora_name = model_id if model_id is not None else SKYRL_LORA_ADAPTER_NAME
306-
lora_sync_path = os.path.join(base_sync_path, model_id) if model_id is not None else base_sync_path
305+
# Defense in depth: api.py validates model_id against ID_PATTERN,
306+
# but route everything through basename here so that even an
307+
# internally-misformed id can't escape lora_sync_path.
308+
safe_model_id = os.path.basename(model_id) if model_id is not None else None
309+
lora_name = safe_model_id if safe_model_id else SKYRL_LORA_ADAPTER_NAME
310+
lora_sync_path = os.path.join(base_sync_path, safe_model_id) if safe_model_id else base_sync_path
307311
await self._save_lora_adapters_and_sync(
308312
peft_model, lora_sync_path, inference_engine_client, lora_name=lora_name
309313
)

skyrl/backends/skyrl_train/workers/megatron/megatron_worker.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -898,7 +898,7 @@ async def _save_lora_adapters_and_sync(
898898
if isinstance(inference_engine_client, RemoteInferenceClient):
899899
await inference_engine_client.load_lora_adapter(lora_name, lora_sync_path)
900900
else:
901-
lora_request = LoraLoadRequest(lora_path=lora_sync_path)
901+
lora_request = LoraLoadRequest(lora_path=lora_sync_path, lora_name=lora_name)
902902
await inference_engine_client.update_named_weights(lora_request)
903903

904904
torch.distributed.barrier()
@@ -927,8 +927,12 @@ async def broadcast_to_inference_engines(
927927
# works: model_id is None, we fall back to the legacy single
928928
# adapter name + shared path.
929929
base_sync_path = self.cfg.policy.model.lora.lora_sync_path
930-
lora_name = model_id if model_id is not None else SKYRL_LORA_ADAPTER_NAME
931-
lora_sync_path = os.path.join(base_sync_path, model_id) if model_id is not None else base_sync_path
930+
# Defense in depth: api.py validates model_id against ID_PATTERN,
931+
# but route everything through basename here so that even an
932+
# internally-misformed id can't escape lora_sync_path.
933+
safe_model_id = os.path.basename(model_id) if model_id is not None else None
934+
lora_name = safe_model_id if safe_model_id else SKYRL_LORA_ADAPTER_NAME
935+
lora_sync_path = os.path.join(base_sync_path, safe_model_id) if safe_model_id else base_sync_path
932936
await self._save_lora_adapters_and_sync(lora_sync_path, inference_engine_client, lora_name=lora_name)
933937
else:
934938
# Extract and send weights using the sender created at init time

skyrl/backends/skyrl_train_backend.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import asyncio
44
import io
55
import os
6+
import shutil
67
import tarfile
78
import tempfile
89

@@ -459,6 +460,31 @@ def _create_colocate_pg(self):
459460

460461
return ResolvedPlacementGroup(pg)
461462

463+
def _cleanup_lora_sync_subdir(self, model_id: str) -> None:
464+
"""Remove the per-tenant lora_sync_path subdir written by the worker.
465+
466+
The Megatron / FSDP workers write each adapter's safetensors into
467+
``lora_sync_path/<basename(model_id)>/`` on every save_weights_for_sampler.
468+
Without cleanup these subdirs accumulate as adapters churn. Mirror the
469+
worker's path construction (incl. basename sanitization) to avoid
470+
deleting anything outside the configured base.
471+
"""
472+
try:
473+
base = self._cfg.policy.model.lora.lora_sync_path
474+
except AttributeError:
475+
return
476+
if not base:
477+
return
478+
safe_id = os.path.basename(model_id)
479+
if not safe_id:
480+
return
481+
subdir = os.path.join(base, safe_id)
482+
try:
483+
shutil.rmtree(subdir, ignore_errors=True)
484+
except OSError as e:
485+
# Best-effort cleanup — log but don't propagate.
486+
logger.warning(f"Failed to remove lora_sync subdir {subdir}: {e}")
487+
462488
def delete_model(self, model_id: str) -> None:
463489
role = self._get_role(model_id)
464490

@@ -469,6 +495,7 @@ def delete_model(self, model_id: str) -> None:
469495
if len(self._model_ids_to_role) > 1:
470496
if role == "policy" and self._base_lora_signature is not None:
471497
self._dispatch.delete_adapter("policy", model_id)
498+
self._cleanup_lora_sync_subdir(model_id)
472499
del self._model_ids_to_role[model_id]
473500
self._model_metadata.pop(model_id, None)
474501
logger.info(f"Removed LoRA adapter '{model_id}'")

0 commit comments

Comments (0)