Commit e4648d4

erictang000, hao-aaron, and claude authored
[tinker][megatron] Multi-LoRA Megatron + Tinker API RL Training (#1621)
## Summary

Multi-tenant LoRA training **and per-adapter sampling** for the SkyRL-Train Megatron + FSDP backends, exposed via the Tinker API. Multiple `rl_loop` clients can run concurrently against a single Tinker server, each owning their own LoRA adapter; the backend swaps the live adapter into Megatron's optimizer/buffer state on demand and registers per-tenant adapters on vLLM by `model_id`.

## What's in the diff

Six files, ~238 LOC net.

**Backend / dispatch (`skyrl/backends/skyrl_train_backend.py`, `workers/worker_dispatch.py`):**

- Drop the v1 single-tenant gates in `sample()` and `save_sampler_checkpoint()`. The batched-sample path now validates every `model_id` in a mixed batch is a known policy, and resolves the inference-engine model name **per request** (`model_id` for LoRA tenants, `resolve_policy_model_name(cfg)` as the FFT / single-tenant fallback). The engine's `find_batchable_sample` legitimately mixes adapters in a single sample call now that multi-LoRA is live.
- `WorkerDispatch.save_weights_for_sampler` takes an optional `model_id`. Before broadcasting it calls `ensure_active_adapter("policy", model_id)` to swap the requested adapter into the optimizer/model on every worker — without this we'd export some other tenant's LoRA weights to vLLM. `model_id` is forwarded into `_broadcast_to_inference_engines` so the worker registers the adapter on vLLM under that name.
- `model_id=None` everywhere preserves the legacy single-tenant code path.

**Per-tenant LoRA save+load (`workers/megatron/megatron_worker.py`, `workers/fsdp/fsdp_worker.py`):**

- Both `_save_lora_adapters_and_sync` and `broadcast_to_inference_engines` take an optional `lora_name` / `model_id`. With `model_id` set, the adapter writes into a per-tenant subdir of `lora_sync_path` and registers on vLLM under that name (via `RemoteInferenceClient.load_lora_adapter`). With `model_id=None` the legacy `SKYRL_LORA_ADAPTER_NAME` shared-path behavior is unchanged.
- The Megatron path additionally requires `merge_lora=False` so vLLM serves each tenant's adapter by name rather than baking it into the base weights.

**Config knobs (`skyrl/train/config/ppo_base_config.yaml`):**

- New `trainer.policy.model.lora.max_loras` (default 1) — maps to vLLM's `max_loras`, the number of adapters that can be active in a single GPU batch.
- New `trainer.policy.model.lora.max_cpu_loras` (default null → vLLM defaults to `max_loras`) — the total LoRA capacity in vLLM's CPU LRU cache.

**Tests (`tests/tinker/skyrl_train/test_multi_lora_megatron.py`):**

- Existing train-only tests (`test_two_adapters_train_independently`, `test_per_adapter_step_isolation`, `test_rank_mismatch_rejected`, `test_delete_then_train_remaining`) continue to pass under the per-tenant config (`merge_lora=False` + `max_loras=4`).
- Removed the train-only `test_sample_with_two_adapters_errors` — the gate it covered is gone.
- Added `test_per_adapter_sample_isolation`: two pristine adapters following an identical training trajectory must produce bit-identical greedy samples through vLLM at every step. Catches: vLLM serving the wrong adapter for a `model_id`, `load_lora_adapter("B", ...)` clobbering adapter A's slot, and optimizer-state regressions that surface as different greedy tokens.
- Added `test_two_adapters_sample_independently`: A trains, B trains in between (registering its own LoRA on vLLM), A trains one more step. Asserts (1) A's tokens advance after the resumed step (A's optimizer state survived B's intervention end-to-end through sampling) and (2) A's final tokens differ from B's (B's adapter sync did not clobber A's slot).
- New `_sample_greedy(tc, name, tok, prompt)` helper: `save_weights_and_get_sampling_client(name=...)` then a deterministic temperature=0 sample.

## Operator contract

To run multi-tenant LoRA on Megatron, set:

```yaml
trainer:
  policy:
    megatron_config:
      lora_config:
        merge_lora: false
    model:
      lora:
        max_loras: <max concurrent adapters>
        max_cpu_loras: <total adapter capacity>
```

`max_cpu_loras` should be ≥ the number of distinct adapters you expect to serve concurrently, otherwise vLLM evicts an adapter mid-run and the next sample 404s. There is no auto-rehydrate yet — operators size for the workload.

The FFT and single-tenant LoRA paths are unchanged; existing configs need no modification.

## Test plan

- [x] `pytest tests/tinker/skyrl_train/test_multi_lora_megatron.py` (6 tests, GPU-gated, requires 2 GPUs for Megatron policy + vLLM)
- [x] All existing single-tenant `test_api.py` paths
- [ ] End-to-end smoke: two `rl_loop` clients with distinct `run_name`s sharing one Tinker server; verify both converge on their respective tasks without contamination

---------

Signed-off-by: ahao-anyscale <ahao@anyscale.com>
Co-authored-by: ahao-anyscale <ahao@anyscale.com>
Co-authored-by: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
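In code, the per-request resolution in the batched-sample path looks roughly like this (a minimal sketch: `requests`, `known_policies`, and `is_lora_tenant` are illustrative stand-ins, and only `resolve_policy_model_name` is named in the diff):

```python
# Sketch of the per-request engine-model resolution described under
# "Backend / dispatch". Stand-in names throughout; not the actual
# implementation in skyrl_train_backend.py.
def resolve_engine_model_names(requests, known_policies, cfg) -> list[str]:
    names = []
    for req in requests:
        if req.model_id not in known_policies:
            # Every model_id in a mixed batch must be a known policy.
            raise ValueError(f"unknown model_id: {req.model_id}")
        if is_lora_tenant(known_policies[req.model_id]):
            # LoRA tenant: sample against the adapter registered on vLLM
            # under this tenant's model_id.
            names.append(req.model_id)
        else:
            # FFT / single-tenant fallback: the shared policy model name.
            names.append(resolve_policy_model_name(cfg))
    return names
```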
1 parent b418d34 commit e4648d4

16 files changed

Lines changed: 515 additions & 113 deletions


docs/content/docs/tinker/architecture.mdx

Lines changed: 5 additions & 3 deletions
@@ -122,9 +122,11 @@ The Tinker SDK sends a `sampling_session_seq_id` field when using the ephemeral

 Persistent saves can be very expensive because they write full model weights to disk on every call. In RL training loops that sync weights every batch, ephemeral mode avoids this overhead entirely. In typical RL loops (e.g., tinker-cookbook's `rl_loop`), every iteration uses ephemeral mode before sampling, and persistent saves are reserved for periodic checkpointing.

-### Single Model Constraint
+### Multiple LoRA tenants

-SkyRL currently supports only one copy of sampling model weights at a time. This differs from Thinking Machines' hosted service that supports arbitrarily many sampling clients attached to various sampling model weights. In SkyRL, after a weight sync, all subsequent `sample()` calls automatically use the updated weights.
+On the Megatron backend, SkyRL supports multiple LoRA adapters trained and sampled concurrently against a single server. Each tenant's adapter weights and optimizer state live in pinned-CPU slots; the live GPU adapter is swapped on demand at the top of every per-model dispatch entry point (forward, forward_backward, optim_step, save_weights_for_sampler). On the inference side, vLLM serves each tenant's adapter by `model_id` after `save_weights_for_sampler` registers it via `load_lora_adapter`. See [Multi-tenancy](./multi_tenancy) for the design and operator contract.
+
+Full-parameter fine-tuning and the FSDP backend remain single-tenant — calling `create_model` a second time on those paths returns an error.

 ## Checkpointing

@@ -163,4 +165,4 @@ Tinker represents training data as `Datum` objects with a `ModelInput` (containi
 - **Shifts** tokens: Tinker pre-shifts inputs/targets, but SkyRL-Train shifts internally, so the backend appends the last target token to reconstruct full sequences
 - Builds `attention_mask`, `loss_mask`, and `response_mask` tensors from token weights

-There is currently a limitation that batch size must be divisible by the data parallelism size (number of GPUs). The engine layer handles batching multiple client requests together before passing them to the backend.
+The engine layer also batches multiple client requests together before passing them to the backend.

docs/content/docs/tinker/configuration.mdx

Lines changed: 18 additions & 1 deletion
@@ -54,7 +54,24 @@ python -m tinker_cookbook.recipes.sl_loop ... lora_rank=32
 python -m tinker_cookbook.recipes.sl_loop ... lora_rank=0
 ```

-No server-side configuration is needed to switch between LoRA and full-parameter fine-tuning.
+No server-side configuration is needed to switch between single-tenant LoRA and full-parameter fine-tuning.
+
+### Multi-tenant LoRA
+
+Hosting multiple LoRA tenants concurrently against one server *does* require server-side configuration on the Megatron backend. At minimum:
+
+```json
+{
+  "trainer.placement.colocate_all": false,
+  "trainer.policy.megatron_config.lora_config.merge_lora": false,
+  "trainer.policy.model.lora.max_loras": <max concurrent adapters in a single batch>,
+  "trainer.policy.model.lora.max_cpu_loras": <total adapter capacity>
+}
+```
+
+`merge_lora: false` is required so vLLM serves each tenant's adapter by name (with `merge_lora: true` vLLM only sees the merged base and per-tenant sampling returns the wrong weights). `max_cpu_loras` must be sized to the peak number of concurrent tenants — there is no on-demand reload, and if vLLM evicts an adapter the next `sample()` against it 404s. All adapters on one server must share the same `(rank, alpha, target_modules)` signature; mismatched signatures are hard-rejected at `create_model`.
+
+See [Multi-tenancy](./multi_tenancy) for the full operator contract and SFT/RL quickstarts.

 ## Full Config Reference

docs/content/docs/tinker/limitations.mdx

Lines changed: 3 additions & 13 deletions
@@ -6,21 +6,11 @@ The Tinker integration is under active development. This page documents current

 ## Current Limitations

-### Single Model
+### Multi-tenant LoRA: Megatron only

-Only one training model and one set of sampling weights can be loaded at a time. Calling `create_model` when a model already exists will return an error. After a weight sync, all subsequent `sample()` calls use the updated weights — there is no support for maintaining multiple sampling snapshots concurrently. To switch models, restart the server.
+Multi-tenant LoRA training and sampling are supported on the **Megatron** backend with vLLM serving per-tenant adapters by name. See [Multi-tenancy](./multi_tenancy) for the operator contract and SL/RL quickstarts. **FSDP2** support is pending, and full-parameter fine-tuning remains single-tenant on both backends — calling `create_model` with `lora_rank=0` while another model exists returns an error.

-### Single-tenant LoRA
-
-Related to the above limitation, even when training with LoRA adaptors, the SkyRL-Train backend only supports one training model and one set of sampling weights. We plan to support training and sampling on multiple LoRA adaptors concurrently in the future.
-
-### Vision Language Models
-
-Vision language models (VLMs) are supported through the Tinker integration. We have validated the path end-to-end on [Qwen3-VL](https://huggingface.co/Qwen/Qwen3-VL-8B-Instruct) — see the [Vision Language cookbook recipe](./cookbook#vision-language-vlm_classifier) for a runnable example. We welcome contributions that extend coverage to additional VLM families.
-
-### Batch Size Constraint
-
-The batch size must be evenly divisible by the data parallelism size (number of GPUs). For example, with 4 GPUs you cannot use a batch size of 5.
+All adapters registered against one server must share the same `(rank, alpha, target_modules)` signature; mismatched signatures are hard-rejected at `create_model`.

 ### No Prompt Logprobs

docs/content/docs/tinker/meta.json

Lines changed: 1 addition & 0 deletions
@@ -4,6 +4,7 @@
   "overview",
   "quickstart",
   "architecture",
+  "multi_tenancy",
   "cookbook",
   "configuration",
   "limitations"
docs/content/docs/tinker/multi_tenancy.mdx

Lines changed: 168 additions & 0 deletions

@@ -0,0 +1,168 @@
---
title: "Multi-tenancy"
---

A single SkyRL Tinker server can host multiple LoRA adapters concurrently against a shared base model. Each adapter is its own Tinker `model_id` and its own client session — multiple `tinker-cookbook` recipes can train and sample in parallel without spinning up a separate server per workload.

This page describes the design, the operator contract, and quickstarts for SFT (`sl_loop.py`) and RL (`rl_loop.py`).

<Callout type="info">
Multi-tenancy is wired on the **Megatron** backend with vLLM serving per-tenant adapters. FSDP2 multi-tenancy and multi-tenant full-parameter fine-tuning are not yet supported — see [Limitations](./limitations).
</Callout>

## How it works

The base model is loaded once on the policy workers and shared across all tenants. Each tenant gets a per-adapter slot in pinned CPU memory holding its LoRA params, optimizer state, and step count; the live GPU adapter is swapped on demand at the top of every per-model dispatch entry point. Clients never reason about which adapter is currently resident — they just call the Tinker API with their `model_id`.

What this means for you (see the sketch after this list):

- **GPU memory is bounded** by the base model plus a few small LoRA buffers, regardless of tenant count. The growth from adding a tenant is in *CPU* memory (one slot per adapter, on the order of `~3× lora_param_bytes_per_DP_shard` — tens of MB for Qwen3-0.6B at rank 32).
- **Swap cost is small** relative to a forward pass — a host→device `tensor.copy_()` plus a DP-group barrier. You should not see noticeable per-call latency from tenant churn.
- **Per-tenant sampling on vLLM** is by `model_id`. The worker exports each tenant's adapter into `lora_sync_path/<model_id>/` on `save_weights_for_sampler` and registers it on vLLM via `load_lora_adapter`. Sampling uses `model=<model_id>` and vLLM routes to the right adapter.
- **Capacity is bounded by `max_cpu_loras`**, vLLM's CPU LRU cache. If you have more concurrent tenants than slots, vLLM evicts one and the next `sample()` against it 404s — there is no on-demand reload. Size for your peak.
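A minimal sketch of the slot-and-swap mechanic, assuming hypothetical `AdapterSlot` / `AdapterManager` containers (only `ensure_active_adapter` and the dispatch entry points are names from the actual implementation):

```python
# Illustrative sketch only: field names and both container classes are
# hypothetical stand-ins for the real Megatron worker dispatch logic.
from dataclasses import dataclass, field

import torch


@dataclass
class AdapterSlot:
    """Pinned-CPU home for one tenant: LoRA params, optimizer state, step count."""
    lora_params: dict[str, torch.Tensor]
    optim_state: dict = field(default_factory=dict)  # fp32 main params + Adam moments
    step: int = 0


class AdapterManager:
    def __init__(self, live_gpu_params: dict[str, torch.Tensor]):
        self.live = live_gpu_params           # the single set of GPU LoRA buffers
        self.slots: dict[str, AdapterSlot] = {}
        self.active: str | None = None

    def register(self, model_id: str) -> None:
        # One pinned-CPU slot per tenant; the GPU footprint does not grow.
        self.slots[model_id] = AdapterSlot(
            {n: torch.zeros_like(p, device="cpu").pin_memory() for n, p in self.live.items()}
        )

    def ensure_active_adapter(self, model_id: str) -> None:
        """Called at the top of every per-model dispatch entry point
        (forward, forward_backward, optim_step, save_weights_for_sampler)."""
        if model_id == self.active:
            return                            # already resident, no copy needed
        if self.active is not None:
            for n, p in self.live.items():    # stash the outgoing tenant on CPU
                self.slots[self.active].lora_params[n].copy_(p, non_blocking=True)
        for n, p in self.live.items():        # host-to-device copy of the incoming tenant
            p.copy_(self.slots[model_id].lora_params[n], non_blocking=True)
        torch.cuda.synchronize()              # stand-in for the DP-group barrier
        self.active = model_id
```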
## Operator contract

Required `--backend-config` keys to run multi-tenant LoRA on Megatron:

```json
{
  "trainer.placement.colocate_all": false,
  "trainer.policy.megatron_config.lora_config.merge_lora": false,
  "trainer.policy.model.lora.max_loras": <max concurrent adapters in a single batch>,
  "trainer.policy.model.lora.max_cpu_loras": <total adapter capacity>
}
```

All adapters must share the same `(rank, alpha, target_modules)` signature. Mismatches are hard-rejected at `create_model` with a `LoRA signature mismatch …` error.

The first `create_model` on a fresh server triggers the policy build and bootstraps the per-tenant adapter slot infrastructure; subsequent `create_model` calls register additional adapter slots and complete in milliseconds. When the *last* registered model is unloaded the server tears down the Ray runtime via `ray.shutdown()`; the next `create_model` rebuilds it.
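Putting the last two paragraphs together, a hypothetical server-side sketch (names are illustrative, not the actual SkyRL implementation):

```python
# Hypothetical sketch of the create_model signature gate and the
# first-tenant / last-tenant lifecycle described above.
from typing import NamedTuple


class LoraSignature(NamedTuple):
    rank: int
    alpha: float
    target_modules: tuple[str, ...]


class TinkerServerState:
    def __init__(self):
        self.signature: LoraSignature | None = None  # fixed by the first create_model
        self.tenants: set[str] = set()

    def create_model(self, model_id: str, sig: LoraSignature) -> None:
        if not self.tenants:
            self.signature = sig       # first tenant: capture the signature
            # ... trigger the (slow) policy build + slot infrastructure here ...
        elif sig != self.signature:
            raise ValueError(f"LoRA signature mismatch: {sig} != {self.signature}")
        self.tenants.add(model_id)     # later tenants: register a slot, return fast

    def delete_model(self, model_id: str) -> None:
        self.tenants.discard(model_id)
        if not self.tenants:
            # Last tenant gone: tear down the Ray runtime (ray.shutdown());
            # the next create_model rebuilds it.
            self.signature = None
```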
## Quickstart — Two SL clients

Run two `tinker-cookbook` `sl_loop` clients in parallel against one Megatron-backed Tinker server.

### 1. Start the server

```bash
uv run --extra tinker --extra megatron -m skyrl.tinker.api \
  --host 0.0.0.0 \
  --port 8000 \
  --base-model Qwen/Qwen3-0.6B \
  --backend megatron \
  --backend-config '{
    "strategy": "megatron",
    "trainer.placement.policy_num_gpus_per_node": 1,
    "trainer.placement.policy_num_nodes": 1,
    "trainer.placement.colocate_all": false,
    "trainer.policy.megatron_config.tensor_model_parallel_size": 1,
    "trainer.policy.megatron_config.pipeline_model_parallel_size": 1,
    "trainer.policy.megatron_config.lora_config.merge_lora": false,
    "trainer.policy.model.lora.max_loras": 2,
    "trainer.policy.model.lora.max_cpu_loras": 2,
    "trainer.logprobs_chunk_size": null
  }'
```

Wait for `init policy model done` after the first client connects.

### 2. Run two `sl_loop` clients

In two separate terminals (in the tinker-cookbook repo):

```bash
# Terminal 2 — client A
TINKER_API_KEY=tml-dummy uv run --with tinker --with tinker-cookbook --with datasets \
  python -m tinker_cookbook.recipes.sl_loop \
  base_url=http://localhost:8000 \
  model_name="Qwen/Qwen3-0.6B" \
  train_on_what=LAST_ASSISTANT_MESSAGE \
  lora_rank=32 \
  log_path=/tmp/sl_loop_a.log
```

```bash
# Terminal 3 — client B
TINKER_API_KEY=tml-dummy uv run --with tinker --with tinker-cookbook --with datasets \
  python -m tinker_cookbook.recipes.sl_loop \
  base_url=http://localhost:8000 \
  model_name="Qwen/Qwen3-0.6B" \
  train_on_what=LAST_ASSISTANT_MESSAGE \
  lora_rank=32 \
  log_path=/tmp/sl_loop_b.log
```

Stagger the launches by ~20s so the second client doesn't race the policy build. Both clients **must** use the same `lora_rank` and `model_name`.

You should see both clients converge on their respective tasks, with NLL trending independently downward in both `sl_loop_a.log` and `sl_loop_b.log`. GPU memory will stay bounded even as the second client connects (single base model + N LoRA slots).
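The same two-tenant setup can also be driven directly from the Tinker SDK. A minimal sketch, assuming the SDK's `ServiceClient` / `create_lora_training_client` entry points and the dummy API key used above:

```python
# Two tenants on one server: each create_lora_training_client call registers
# its own adapter slot. Sketch only; exact SDK call shapes may differ.
import os

import tinker

os.environ["TINKER_API_KEY"] = "tml-dummy"
service = tinker.ServiceClient(base_url="http://localhost:8000")

# Both tenants must share the same base model and LoRA rank; the server
# hard-rejects mismatched signatures at create_model.
client_a = service.create_lora_training_client(base_model="Qwen/Qwen3-0.6B", rank=32)
client_b = service.create_lora_training_client(base_model="Qwen/Qwen3-0.6B", rank=32)

# Each client now trains and samples independently; the RL quickstart below
# exercises the per-tenant sampling path.
```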
## Quickstart — Two RL clients

Two `rl_loop` clients each train and sample independently against one server. RL exercises the per-tenant `save_weights_for_sampler` + `sample(model=<model_id>)` path.

### 1. Start the server

```bash
uv run --extra tinker --extra megatron -m skyrl.tinker.api \
  --host 0.0.0.0 \
  --port 8000 \
  --base-model Qwen/Qwen3-0.6B \
  --backend megatron \
  --backend-config '{
    "strategy": "megatron",
    "trainer.placement.policy_num_gpus_per_node": 4,
    "trainer.placement.policy_num_nodes": 1,
    "trainer.placement.colocate_all": false,
    "trainer.policy.megatron_config.tensor_model_parallel_size": 1,
    "trainer.policy.megatron_config.pipeline_model_parallel_size": 1,
    "trainer.policy.megatron_config.lora_config.merge_lora": false,
    "trainer.micro_train_batch_size_per_gpu": 64,
    "trainer.micro_forward_batch_size_per_gpu": 64,
    "generator.inference_engine.num_engines": 1,
    "generator.inference_engine.tensor_parallel_size": 1,
    "trainer.policy.model.lora.max_loras": 2,
    "trainer.policy.model.lora.max_cpu_loras": 2,
    "trainer.logprobs_chunk_size": null
  }'
```

Critical knobs vs the SL quickstart:

- `colocate_all: false` is required. Sampling and training must be able to progress independently across client calls, so inference engines and trainer workers are placed on different GPUs.
- `merge_lora: false` is required. With `merge_lora: true`, vLLM serves the merged base model and `sample(model=<adapter>)` returns the wrong tenant's weights.
- `max_loras` ≥ number of adapters in a single batch (typically equal to the client count).
- `max_cpu_loras` must be ≥ the number of adapters you expect to serve concurrently. There is no on-demand reload — if vLLM evicts an adapter, its next `sample()` 404s.
### 2. Run two `rl_loop` clients

```bash
# Terminal 2 — client A
TINKER_API_KEY=tml-dummy uv run --with tinker --with tinker-cookbook --with datasets --with torch \
  python -m tinker_cookbook.recipes.rl_loop \
  base_url=http://localhost:8000 \
  model_name="Qwen/Qwen3-0.6B" \
  lora_rank=32 \
  log_path=/tmp/rl_loop_a.log
```

```bash
# Terminal 3 — client B
TINKER_API_KEY=tml-dummy uv run --with tinker --with tinker-cookbook --with datasets --with torch \
  python -m tinker_cookbook.recipes.rl_loop \
  base_url=http://localhost:8000 \
  model_name="Qwen/Qwen3-0.6B" \
  lora_rank=32 \
  log_path=/tmp/rl_loop_b.log
```

Stagger by ~20s. Both clients **must** use the same `lora_rank` and `model_name`.

You should see both clients' rewards trend upward independently in `rl_loop_a.log` and `rl_loop_b.log`, vLLM logs showing two distinct adapter names registered with `sample` requests routed to each, and GPU memory staying bounded (single base model, two LoRA adapters, the CPU LRU holding the same two).
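To spot-check per-tenant isolation by hand, you can take a greedy sample per tenant, mirroring the test suite's `_sample_greedy` helper (the SDK call shapes below are assumptions):

```python
# Sketch of a per-tenant greedy-sample check, mirroring the test suite's
# `_sample_greedy` helper. Exact SDK signatures are assumptions.
from tinker import types


def sample_greedy(training_client, name: str, tokenizer, prompt: str) -> list[int]:
    # Export this tenant's adapter and register it on vLLM under `name`.
    sampling_client = training_client.save_weights_and_get_sampling_client(name=name)
    result = sampling_client.sample(
        prompt=types.ModelInput.from_ints(tokenizer.encode(prompt)),
        sampling_params=types.SamplingParams(max_tokens=32, temperature=0.0),  # deterministic
        num_samples=1,
    ).result()
    return result.sequences[0].tokens

# Two tenants trained identically must return bit-identical greedy tokens;
# tenants trained differently must diverge.
```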
## Troubleshooting

- **`LoRA signature mismatch`** — clients passed different `(rank, alpha, target_modules)`. All adapters on one server share a signature, captured from the first `create_model`.
- **`sample()` 404 on `lora_name=…`** — either `save_sampler_checkpoint` wasn't called for that `model_id` before sampling, or `max_cpu_loras` is too low and vLLM evicted the adapter. Check the vLLM server log.
- **Server hangs on the second `create_model`** — the first policy build hasn't finished. Wait for `init policy model done` before starting subsequent clients.
- **CPU OOM on the Nth client** — each adapter slot holds LoRA params + fp32 main params + Adam moments, roughly `~3× lora_param_bytes_per_DP_shard`. For Qwen3-0.6B at rank 32 this is on the order of tens of MB per slot; for larger models scale accordingly. Reduce concurrent adapters or move to a host with more RAM. A back-of-envelope estimate follows this list.
- **Sample returns the wrong tenant's output** — confirm `merge_lora: false` is set on the Megatron config; with merge enabled vLLM only sees the merged base.
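For the CPU sizing bullet, a back-of-envelope estimate (illustrative Qwen3-0.6B-ish shapes with a uniform `d_model` approximation; real layouts depend on the exact target modules, GQA dims, optimizer dtypes, and DP sharding):

```python
# Rough per-slot CPU memory estimate for one LoRA tenant. Illustrative only.
def lora_slot_bytes(d_model: int, n_layers: int, n_target_mats: int,
                    rank: int, dp_size: int) -> float:
    # Each targeted matrix contributes A (rank x d_in) + B (d_out x rank);
    # approximate both dims by d_model.
    lora_params = n_layers * n_target_mats * 2 * rank * d_model
    param_bytes = lora_params * 2          # bf16 adapter weights
    shard = param_bytes / dp_size          # one DP shard per worker slot
    return 3 * shard                       # plus fp32 main copy + Adam moments, ~3x

# Qwen3-0.6B-ish shapes at rank 32, DP=1 (7 targeted matrices per layer):
est = lora_slot_bytes(d_model=1024, n_layers=28, n_target_mats=7, rank=32, dp_size=1)
print(f"~{est / 2**20:.0f} MiB per tenant slot")   # on the order of tens of MiB
```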

docs/content/docs/tinker/overview.mdx

Lines changed: 5 additions & 3 deletions
@@ -40,15 +40,17 @@ SkyRL brings the Tinker API to your own hardware. By utilizing the fully Tinker
 | FSDP2 strategy | Supported |
 | Megatron strategy | Supported |
 | Vision models | Supported |
-| Multi-tenant LoRA | Not yet supported |
-| Multi-model sampling | Not yet supported |
-| Multi-model training | Not yet supported |
+| Multi-tenant LoRA training (Megatron + vLLM) | Supported — see [Multi-tenancy](./multi_tenancy) |
+| Multi-tenant LoRA sampling (Megatron + vLLM) | Supported — see [Multi-tenancy](./multi_tenancy) |
+| Multi-tenant LoRA on FSDP2 | Not yet supported |
+| Multi-tenant full-parameter fine-tuning | Not yet supported |

 For more details, see the [Limitations & Roadmap](./limitations) page.

 ## Next Steps

 - [Quickstart](./quickstart) - Start a SkyRL Tinker server and run your first training script
 - [Architecture](./architecture) - Understand how SkyRL implements the Tinker API
+- [Multi-tenancy](./multi_tenancy) - Run multiple LoRA tenants concurrently against one server
 - [Cookbook Scripts](./cookbook) - Run the official tinker-cookbook recipes on SkyRL
 - [Limitations & Roadmap](./limitations) - Known limitations and future plans

skyrl/backends/skyrl_train/inference_engines/vllm/vllm_engine.py

Lines changed: 20 additions & 8 deletions
@@ -308,10 +308,16 @@ async def init_weight_update_communicator(self, init_info: "WeightSyncInitInfo")
             args=(pickled_init_info,),
         )

-    async def _load_lora_from_disk(self, lora_path: str):
-        """Load LoRA adapters from disk using vLLM's native add_lora method."""
+    async def _load_lora_from_disk(self, lora_path: str, lora_name: str = ""):
+        """Load LoRA adapters from disk using vLLM's native add_lora method.
+
+        When ``lora_name`` is empty (legacy single-tenant), a numeric name is
+        generated. Multi-tenant callers pass ``lora_name`` so subsequent
+        ``model=<lora_name>`` sampling routes to the right adapter.
+        """
         lora_id = int(time.time_ns() % 0x7FFFFFFF)
-        lora_request = LoRARequest(lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=lora_path)
+        name = lora_name or f"{lora_id}"
+        lora_request = LoRARequest(lora_name=name, lora_int_id=lora_id, lora_path=lora_path)
         result = self.llm.llm_engine.add_lora(lora_request)
         return result

@@ -320,7 +326,7 @@ async def update_named_weights(self, request: WeightUpdateRequest):

         # Handle LoRA disk loading request
         if isinstance(request, LoraLoadRequest):
-            return await self._load_lora_from_disk(request.lora_path)
+            return await self._load_lora_from_disk(request.lora_path, lora_name=request.lora_name)

         if not len(request):
             raise ValueError("Weight update request must not be empty")

@@ -453,10 +459,16 @@ def _create_ray_prometheus_stat_loggers(self):
             )
         return None

-    async def _load_lora_from_disk(self, lora_path: str):
-        """Load LoRA adapters from disk using vLLM's native add_lora method."""
+    async def _load_lora_from_disk(self, lora_path: str, lora_name: str = ""):
+        """Load LoRA adapters from disk using vLLM's native add_lora method.
+
+        When ``lora_name`` is empty (legacy single-tenant), a numeric name is
+        generated. Multi-tenant callers pass ``lora_name`` so subsequent
+        ``model=<lora_name>`` sampling routes to the right adapter.
+        """
         lora_id = int(time.time_ns() % 0x7FFFFFFF)
-        lora_request = LoRARequest(lora_name=f"{lora_id}", lora_int_id=lora_id, lora_path=lora_path)
+        name = lora_name or f"{lora_id}"
+        lora_request = LoRARequest(lora_name=name, lora_int_id=lora_id, lora_path=lora_path)
         result = await self.llm.add_lora(lora_request)
         return result

@@ -539,7 +551,7 @@ async def update_named_weights(self, request: WeightUpdateRequest):

         # Check for LoRA disk loading request
         if isinstance(request, LoraLoadRequest):
-            return await self._load_lora_from_disk(request.lora_path)
+            return await self._load_lora_from_disk(request.lora_path, lora_name=request.lora_name)

         if not len(request):
             raise ValueError("Weight update request must not be empty")
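For reference, the multi-tenant caller side of this API looks roughly like this (a sketch: `LoraLoadRequest` and `update_named_weights` appear in the diff above, while the surrounding variables are placeholders):

```python
# Caller-side sketch: register one tenant's adapter on vLLM by name.
# `engine` stands in for a vLLM engine exposing update_named_weights;
# lora_sync_path / model_id come from the worker's per-tenant save path.
async def register_tenant_adapter(engine, lora_sync_path: str, model_id: str):
    request = LoraLoadRequest(
        lora_path=f"{lora_sync_path}/{model_id}",  # per-tenant subdir written on save_weights_for_sampler
        lora_name=model_id,                        # vLLM serves the adapter under this name
    )
    # Routes to _load_lora_from_disk(lora_path, lora_name=model_id) above;
    # subsequent sampling with model=<model_id> hits this adapter.
    return await engine.update_named_weights(request)
```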
