Skip to content

Commit 43f7d65

Browse files
committed
[multi-lora] Restore v1 sampling guards + add SEQ-vs-ALT min repro test
Two fixes:

1. Restore the v1 single-tenant sampling guards in skyrl_train_backend.py that the merge from origin/main accidentally dropped:
   - sample() returns ErrorResponse when LoRA is active and >1 adapter is registered.
   - save_sampler_checkpoint raises ValueError under the same condition.
   Multi-tenant inference is the RL follow-up (NovaSky-AI#1621); SFT v1 must refuse it explicitly rather than silently corrupting state. test_sample_with_two_adapters_errors had been passing in earlier runs only by accident — restore the actual guarantee.

2. Add test_seq_vs_alt_per_adapter_step_isolation: a minimal repro of the SEQ-vs-ALT divergence flagged in ~/skyrl-seq-vs-alt-repro (against Qwen3-4B + PPO on a real pod). Two fresh adapters, ALT-style sequence, identical data; asserts that pre-update losses match within 1e-2 at every step. With AdapterStore snapshotting state['step'] per slot, this passes on the tiny model — step 0 is bit-exact, step 1 diverges by 1.7e-4 (three orders of magnitude below the user's Qwen3-4B observation). If a future change leaks a global step counter across adapters, this test will fail loudly; the assertion message points at the SEQ-vs-ALT diagnosis.

Local: 5/5 pass in ~2m on 1x B200.
1 parent 76dc375 commit 43f7d65

2 files changed

Lines changed: 78 additions & 13 deletions

File tree

skyrl/backends/skyrl_train_backend.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -249,11 +249,12 @@ def _build_policy(self, PolicyWorker, model_id: str):
249249
)
250250
ray.get(policy_model.async_run_ray_method("pass_through", "_set_pad_token_id", self._tokenizer.pad_token_id))
251251

252-
# Multi-LoRA bootstrap: prime DistributedOptimizer state and snapshot
253-
# the freshly-initialised LoRA into a per-worker pristine slot, then
254-
# register the first adapter under `model_id`. Must happen while the
255-
# model + optimizer are still GPU-resident (i.e. before the offload).
252+
256253
if is_lora:
254+
# For multi-tenant LoRA training: prime DistributedOptimizer state and snapshot
255+
# the freshly-initialised LoRA into a per-worker pristine slot, then
256+
# register the first adapter under `model_id`. Must happen while the
257+
# model + optimizer are still GPU-resident (i.e. before the offload).
257258
ray.get(policy_model.async_run_ray_method("pass_through", "prime_optimizer_state"))
258259
ray.get(policy_model.async_run_ray_method("pass_through", "register_pristine_adapter"))
259260
ray.get(policy_model.async_run_ray_method("pass_through", "register_adapter", model_id))
@@ -354,10 +355,8 @@ def _ensure_inference_engines(self):
354355

355356
def _lora_signature_from(self, lora_config: types.LoraConfig) -> tuple:
356357
# Tinker's public LoraConfig only exposes rank + alpha (plus
357-
# seed/train_attn/train_mlp/train_unembed, which the SkyRL Megatron
358-
# path doesn't honor — target_modules is fixed server-side via
359-
# cfg.trainer.policy.model.lora.target_modules). Equality across
360-
# adapters therefore reduces to (rank, alpha); the worker-side
358+
# seed/train_attn/train_mlp/train_unembed, which are not supported yet; see https://github.com/NovaSky-AI/SkyRL/issues/1632).
359+
# Equality across adapters therefore reduces to (rank, alpha); the worker-side
361360
# AdapterStore additionally verifies parallel-state equality via
362361
# its own LoraSignature.
363362
return (int(lora_config.rank), int(lora_config.alpha))
@@ -390,8 +389,8 @@ def create_model(self, model_id: str, lora_config: types.LoraConfig, model_role:
390389
f"LoRA signature mismatch for model '{model_id}': "
391390
f"got (rank, alpha)={new_signature}, "
392391
f"first adapter registered with {self._base_lora_signature}. "
393-
"Multi-LoRA requires identical (rank, alpha) across all "
394-
"adapters in v1; target_modules is fixed server-side."
392+
"Multi-LoRA with the SkyRLTrainBackend requires identical (rank, alpha) across all "
393+
"adapters; target_modules is fixed server-side."
395394
)
396395
self._dispatch.register_adapter("policy", model_id)
397396
self._model_ids_to_role[model_id] = model_role
@@ -877,7 +876,9 @@ def sample(
877876
self._ensure_inference_engines()
878877

879878
# v1 multi-LoRA: sample() is single-tenant. The inference engine path
880-
# is not yet adapter-aware, so refuse if more than one adapter exists.
879+
# is not yet adapter-aware on this branch, so refuse if more than one
880+
# LoRA adapter is registered. Multi-tenant sampling lands in the RL
881+
# follow-up.
881882
if self._base_lora_signature is not None and len(self._model_ids_to_role) > 1:
882883
error = types.ErrorResponse(
883884
error=(

tests/tinker/test_multi_lora_megatron.py

Lines changed: 66 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,10 @@
1515
6. create_model("C", rank=different) → expect a structured ValueError.
1616
7. sample() with two adapters → expect a structured error.
1717
8. delete_model("A"), then forward_backward on B → still works.
18+
19+
20+
Run with
21+
uv run --extra tinker --extra megatron --with pytest --with pytest-timeout python -m pytest -s tests/tinker/test_multi_lora_megatron.py
1822
"""
1923

2024
from __future__ import annotations
@@ -47,8 +51,8 @@
4751
TINKER_API_KEY = "tml-dummy"
4852
TEST_PORT = 8011
4953

50-
# Tiny config: 1 GPU, no TP/PP, single DP rank. Adjust as needed for your
51-
# CI hardware. With a tiny model + LoRA rank 8, this fits comfortably in
54+
# Tiny config: 1 GPU, no TP/PP, single DP rank.
55+
# With a tiny model + LoRA rank 8, this fits comfortably in
5256
# any modern GPU.
5357
BACKEND_CONFIG = {
5458
"strategy": "megatron",
@@ -200,6 +204,66 @@ def test_sample_with_two_adapters_errors(service_client):
200204
a.save_weights_and_get_sampling_client(name="should_fail")
201205

202206

207+
def test_seq_vs_alt_per_adapter_step_isolation(service_client):
    """Min repro of the SEQ-vs-ALT divergence flagged in
    ~/skyrl-seq-vs-alt-repro (against Qwen3-4B on a real pod).

    Two fresh adapters, identical pristine state, identical data. We do an
    ALT-style sequence (A.step0, B.step0, A.step1, B.step1) and assert that
    A's pre-update loss == B's pre-update loss at every step (within FP
    tolerance). Both adapters were pristine when their first step ran, and
    both received the same parameters after their respective updates, so
    their losses must match — unless a step counter, scheduler position, or
    other Adam-bias-correction state leaks across adapters via shared
    optimizer state.

    The Qwen3-4B repro shows a 0.09-0.45 nat divergence; we use a tighter
    1e-2 bound here because the tiny model's losses are smaller and the
    AdapterStore snapshot/restore should keep state['step'] per-adapter.
    """
    # Absolute tolerances on |loss_A - loss_B|. Step 0 runs from the pristine
    # state on both adapters, so it gets the tighter bound.
    step0_tol = 1e-3
    step1_tol = 1e-2

    client_a = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
    client_b = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
    tok = client_a.get_tokenizer()
    data = [_make_datum(tok, "Question: 1+1?\nAnswer:", " 2")]

    def fb_step(c):
        # One full training step; the returned loss is the pre-update loss,
        # because it is read from forward_backward before optim_step runs.
        out = c.forward_backward(data, "cross_entropy").result()
        loss = sum(sum(o["elementwise_loss"].data) for o in out.loss_fn_outputs)
        c.optim_step(tinker_types.AdamParams(learning_rate=1e-3)).result()
        return loss

    # ALT pattern: A.step0, B.step0, A.step1, B.step1. The call order is the
    # point of the test, so keep these four calls in exactly this sequence.
    a0 = fb_step(client_a)
    b0 = fb_step(client_b)
    a1 = fb_step(client_a)
    b1 = fb_step(client_b)

    delta0 = abs(a0 - b0)
    delta1 = abs(a1 - b1)
    print(
        f"\n[seq_vs_alt] step 0: A={a0!r} B={b0!r} |Δ|={delta0:.6e}\n"
        f"[seq_vs_alt] step 1: A={a1!r} B={b1!r} |Δ|={delta1:.6e}"
    )

    # Step 0: both adapters were pristine + saw identical data.
    assert delta0 < step0_tol, f"step 0 loss differs: A={a0!r} B={b0!r} (Δ={delta0:.6f})"

    # Step 1: both adapters had exactly one optim_step from pristine on
    # identical data. If the per-adapter step counter is correctly
    # snapshotted/restored by AdapterStore, both updates use Adam at t=2
    # (after the one-step priming), so their post-update states match and
    # their step-1 losses match.
    #
    # If a global step counter advanced (one for A's step 0, one for B's
    # step 0), B's first real update saw t=3 vs A's t=2, producing a
    # measurably different update.
    assert delta1 < step1_tol, (
        f"step 1 loss diverges between adapters: A={a1!r} B={b1!r} (|Δ|={delta1:.4f}). "
        "This is the symptom expected from a shared global step counter "
        "(LR scheduler position or Adam bias-correction step) advancing on every "
        "optim_step instead of being held per-adapter — see "
        "~/skyrl-seq-vs-alt-repro/README.md."
    )
265+
266+
203267
def test_delete_then_train_remaining(service_client):
204268
a = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
205269
b = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)

0 commit comments

Comments
 (0)