
Commit ddb87c8

erictang000 and claude committed
[multi-lora] Tighten SEQ-vs-ALT test to bit-exact + add Qwen3-0.6B variant
Two related changes to the SEQ-vs-ALT regression coverage:

1. test_seq_vs_alt_per_adapter_step_isolation (tiny model): tighten from
   |Δ| < 1e-2 to bit-exact equality. With aca96d0's per-param-group state
   snapshot fix in place, both step 0 and step 1 are bit-exact across A
   and B on the tiny test model. Pre-fix the delta was 1.7e-4 — small but
   non-zero, so the bit-exact bound catches the regression even at the
   tiny scale.

2. New test_seq_vs_alt_qwen3_0_6b_cross_scenario: spins up a separate
   server fixture on Qwen/Qwen3-0.6B (~1.2 GB bf16; fits a single L4)
   and exercises the full upstream repro shape from
   ~/skyrl-seq-vs-alt-repro:
   - ALT scenario: 2 fresh adapters, A.0 / B.0 / A.1 / B.1
   - SEQ scenario: 2 more fresh adapters, A.0 / A.1 / B.0 / B.1
   - Asserts within-scenario A == B at every step.
   - Asserts cross-scenario A_ALT step N == A_SEQ step N at every step.

Why we need this in addition to the tiny-model test: pre-fix on the tiny
model the divergence was 1.7e-4 (within FP noise); on Qwen3-4B it was
0.45 nats. Qwen3-0.6B + cross_entropy is the smallest setup that
surfaces the bug at a magnitude that's clearly real signal, while still
fitting on the cheapest single-GPU box that runs Megatron LoRA SFT.

Currently bit-exact on multi_lora @ aca96d0 (this branch HEAD):

  ALT 0: A=B=22.81096936017275
  ALT 1: A=B=19.674404971301556
  SEQ 0: A=B=22.81096936017275  (= ALT 0)
  SEQ 1: A=B=19.674404971301556 (= ALT 1)

Local: 6/6 pass.

Co-Authored-By: Eric Tang <erictang000@gmail.com>
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
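The per-param-group snapshot fix that the commit message refers to can be illustrated with a toy, outside any real framework. Everything below is hypothetical: `ToyFusedAdam`, `snapshot`, and `restore` are illustration names, not AdapterStore's actual API. The toy only mimics the one relevant trait of TE FusedAdam, namely that the bias-correction counter lives in `param_groups[g]['step']` rather than in the per-param `state` dict:

```python
import copy

class ToyFusedAdam:
    """Hypothetical stand-in: the bias-correction counter lives in
    param_groups[g]['step'], not in the per-param state dict."""
    def __init__(self, params):
        self.param_groups = [{"params": list(params), "step": 0}]
        self.state = {id(p): {"exp_avg": 0.0} for p in params}

    def step(self):
        for group in self.param_groups:
            group["step"] += 1  # advances globally unless snapshotted per adapter

def snapshot(opt):
    # The buggy snapshot saved only opt.state; the fix also saves the
    # per-group scalars such as 'step'.
    return {
        "state": copy.deepcopy(opt.state),
        "group_steps": [g["step"] for g in opt.param_groups],
    }

def restore(opt, snap):
    opt.state = copy.deepcopy(snap["state"])
    for g, s in zip(opt.param_groups, snap["group_steps"]):
        g["step"] = s

# Two adapters sharing one optimizer instance, ALT order: A.0, B.0, A.1, B.1.
opt = ToyFusedAdam(params=[object()])
snaps = {"A": snapshot(opt), "B": snapshot(opt)}  # both pristine

def adapter_step(name):
    restore(opt, snaps[name])
    opt.step()
    snaps[name] = snapshot(opt)

for name in ("A", "B", "A", "B"):
    adapter_step(name)

# With the group-step snapshot, each adapter observed exactly its own
# 2 steps (t=2), not the 4 global optim_steps that fired.
print(snaps["A"]["group_steps"], snaps["B"]["group_steps"])  # [2] [2]
```

Without the `group_steps` entry, the loop above would leave `step` at 4 for whichever adapter ran last, which is exactly the cross-adapter leak the tightened tests pin down.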
1 parent aca96d0 commit ddb87c8

1 file changed

Lines changed: 127 additions & 18 deletions

File tree

tests/tinker/test_multi_lora_megatron.py

@@ -65,7 +65,7 @@


 @contextmanager
-def _api_server(port: int, backend_config: dict | None = None):
+def _api_server(port: int, backend_config: dict | None = None, base_model: str = BASE_MODEL):
     with tempfile.TemporaryDirectory() as tmp_dir:
         log_path = os.path.join(tmp_dir, "server.log")
         db_path = os.path.join(tmp_dir, "server.db")
@@ -84,7 +84,7 @@ def _api_server(port: int, backend_config: dict | None = None):
         "--port",
         str(port),
         "--base-model",
-        BASE_MODEL,
+        base_model,
         "--backend",
         "megatron",
         "--backend-config",
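For context on the helper being extended above, the subprocess-in-tempdir contextmanager pattern it uses can be sketched as follows. This is a minimal stand-in, not the test's actual launcher: a `sleep` command substitutes for the real Megatron API server process, and only the log-file plumbing is shown:

```python
import os
import subprocess
import tempfile
from contextlib import contextmanager

@contextmanager
def _toy_server(port: int, base_model: str = "tiny-model"):
    # Each run gets its own scratch dir for logs and state, removed on exit.
    with tempfile.TemporaryDirectory() as tmp_dir:
        log_path = os.path.join(tmp_dir, "server.log")
        with open(log_path, "w") as log:
            # Hypothetical stand-in command; the real helper launches the
            # Megatron-backed API server with --port / --base-model flags.
            proc = subprocess.Popen(
                ["sleep", "60"], stdout=log, stderr=subprocess.STDOUT
            )
        try:
            yield proc
        finally:
            proc.terminate()
            proc.wait(timeout=10)

with _toy_server(8012, base_model="Qwen/Qwen3-0.6B") as p:
    running = p.poll() is None  # True while the child process is alive
print(running)
```

Threading `base_model` through as a keyword argument, as the diff does, lets a second fixture reuse the same launcher for a different checkpoint without duplicating the teardown logic.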
@@ -148,6 +148,26 @@ def service_client(server):
     return tinker.ServiceClient(base_url=f"http://0.0.0.0:{TEST_PORT}/", api_key=TINKER_API_KEY)


+# ---- Qwen3-0.6B fixtures (separate server, runs once per session) ----
+QWEN3_0_6B = "Qwen/Qwen3-0.6B"
+QWEN3_0_6B_PORT = 8012
+
+
+@pytest.fixture(scope="module")
+def qwen3_0_6b_server():
+    """Larger-model server for tests where the tiny model's loss surface is
+    too small to surface real bugs. Qwen3-0.6B fits comfortably on a single
+    L4 (24 GB): ~1.2 GB bf16 weights + ~1 GB LoRA optimizer state + ~1.2 GB
+    vLLM."""
+    with _api_server(QWEN3_0_6B_PORT, base_model=QWEN3_0_6B) as proc:
+        yield proc
+
+
+@pytest.fixture
+def qwen3_0_6b_service_client(qwen3_0_6b_server):
+    return tinker.ServiceClient(base_url=f"http://0.0.0.0:{QWEN3_0_6B_PORT}/", api_key=TINKER_API_KEY)
+
+
 def test_two_adapters_train_independently(service_client):
     """Two LoRA adapters share the same base model; training one must not
     contaminate the other's weights.
@@ -242,25 +262,114 @@ def fb_step(c):
         f"[seq_vs_alt] step 1: A={a1!r} B={b1!r} |Δ|={abs(a1 - b1):.6e}"
     )

-    # Step 0: both adapters were pristine + saw identical data.
-    assert abs(a0 - b0) < 1e-3, f"step 0 loss differs: A={a0!r} B={b0!r} (Δ={abs(a0 - b0):.6f})"
+    # Step 0: both adapters were pristine + saw identical data → bit-exact.
+    assert a0 == b0, f"step 0 loss differs: A={a0!r} B={b0!r} (Δ={abs(a0 - b0):.6e})"

     # Step 1: both adapters had exactly one optim_step from pristine on
-    # identical data. If the per-adapter step counter is correctly
-    # snapshotted/restored by AdapterStore, both updates use Adam at t=2
-    # (after the one-step priming), so their post-update states match and
-    # their step-1 losses match.
+    # identical data. With AdapterStore correctly snapshotting both per-
+    # param state and per-param-group state (TE FusedAdam tracks the
+    # bias-correction step counter at the group level — see
+    # NovaSky-AI/SkyRL multi_lora @ aca96d0c), both updates use t=2 and
+    # the post-update parameters are bit-identical. Bit-exact loss
+    # follows.
     #
-    # If a global step counter advanced (one for A's step 0, one for B's
-    # step 0), B's first real update saw t=3 vs A's t=2, producing a
-    # measurably different update.
-    delta = abs(a1 - b1)
-    assert delta < 1e-2, (
-        f"step 1 loss diverges between adapters: A={a1!r} B={b1!r} (|Δ|={delta:.4f}). "
-        f"Symmetric prediction of a shared global step counter "
-        f"(LR scheduler position or Adam bias-correction step) advancing on every "
-        f"optim_step instead of being held per-adapter — see "
-        f"~/skyrl-seq-vs-alt-repro/README.md."
+    # Pre-fix on the tiny test model this delta was 1.7e-4 (small but
+    # non-zero — the bug WAS present, just below FP-noise on a tiny
+    # output distribution). On Qwen3-4B + PPO it was 0.117 nats. The
+    # bit-exact assertion catches both regressions.
+    assert a1 == b1, (
+        f"step 1 loss diverges between adapters: A={a1!r} B={b1!r} (|Δ|={abs(a1 - b1):.6e}). "
+        f"Symmetric prediction of a shared global step counter (TE FusedAdam's "
+        f"`param_groups[g]['step']`) advancing on every optim_step instead of being "
+        f"held per-adapter — see ~/skyrl-seq-vs-alt-repro/README.md."
+    )
+
+
+def test_seq_vs_alt_qwen3_0_6b_cross_scenario(qwen3_0_6b_service_client):
+    """Larger-model SEQ-vs-ALT regression test, mirroring the upstream
+    repro at ~/skyrl-seq-vs-alt-repro (which uses Qwen3-4B + PPO).
+
+    The bug: TE FusedAdam tracks bias-correction `step` at the
+    `optimizer.param_groups[g]['step']` level, not in the per-param state
+    dict. Before NovaSky-AI/SkyRL multi_lora @ aca96d0c, AdapterStore only
+    snapshotted per-param state, so this counter advanced globally across
+    adapters — giving each new adapter the wrong bias correction
+    proportional to how many optim_steps had fired since policy build.
+
+    Why we need this test alongside the tiny-model one: on
+    trl-internal-testing/tiny-Qwen3ForCausalLM the divergence pre-fix was
+    only 1.7e-4 (well within FP noise). On Qwen3-4B + PPO it was 0.45
+    nats. Qwen3-0.6B + cross_entropy is the smallest configuration that
+    reliably surfaces the bug at a non-trivial magnitude (~0.05-0.2 nats
+    pre-fix) while still fitting on a single L4 GPU.
+
+    The two killer signals from the upstream repro:
+    1. Within scenario: A's loss == B's loss at every step (both
+       adapters were pristine + saw identical data, ran identical
+       optim_step sequence per adapter — modulo when in the global
+       schedule that happened).
+    2. Cross scenario: A_ALT step N == A_SEQ step N (A's per-adapter
+       trajectory must NOT depend on whether B trained in between).
+
+    With the AdapterStore param-group fix in place, all four runs
+    (A_ALT, B_ALT, A_SEQ, B_SEQ) at every step land on the same loss
+    bit-for-bit.
+    """
+    sc = qwen3_0_6b_service_client
+
+    # ---- ALT scenario ----
+    a_alt = sc.create_lora_training_client(base_model=QWEN3_0_6B, rank=8)
+    b_alt = sc.create_lora_training_client(base_model=QWEN3_0_6B, rank=8)
+    tok = a_alt.get_tokenizer()
+    data = [_make_datum(tok, "Question: 1+1?\nAnswer:", " 2")]
+
+    def fb_step(c, lr=1e-4):
+        out = c.forward_backward(data, "cross_entropy").result()
+        loss = sum(sum(o["elementwise_loss"].data) for o in out.loss_fn_outputs)
+        c.optim_step(tinker_types.AdamParams(learning_rate=lr)).result()
+        return loss
+
+    # ALT order: A.0, B.0, A.1, B.1
+    alt_a0 = fb_step(a_alt)
+    alt_b0 = fb_step(b_alt)
+    alt_a1 = fb_step(a_alt)
+    alt_b1 = fb_step(b_alt)
+
+    # ---- SEQ scenario: fresh adapters, sequential order ----
+    a_seq = sc.create_lora_training_client(base_model=QWEN3_0_6B, rank=8)
+    b_seq = sc.create_lora_training_client(base_model=QWEN3_0_6B, rank=8)
+
+    # SEQ order: A.0, A.1, B.0, B.1
+    seq_a0 = fb_step(a_seq)
+    seq_a1 = fb_step(a_seq)
+    seq_b0 = fb_step(b_seq)
+    seq_b1 = fb_step(b_seq)
+
+    print(
+        f"\n[seq_vs_alt_qwen3_0_6b] ALT 0: A={alt_a0!r} B={alt_b0!r}\n"
+        f"[seq_vs_alt_qwen3_0_6b] ALT 1: A={alt_a1!r} B={alt_b1!r}\n"
+        f"[seq_vs_alt_qwen3_0_6b] SEQ 0: A={seq_a0!r} B={seq_b0!r}\n"
+        f"[seq_vs_alt_qwen3_0_6b] SEQ 1: A={seq_a1!r} B={seq_b1!r}\n"
+    )
+
+    # Within-scenario isolation: A and B saw the same data and were both
+    # pristine when their respective fb_step(N) ran.
+    assert alt_a0 == alt_b0, f"ALT step 0 cross-adapter: A={alt_a0!r} B={alt_b0!r}"
+    assert alt_a1 == alt_b1, f"ALT step 1 cross-adapter: A={alt_a1!r} B={alt_b1!r}"
+    assert seq_a0 == seq_b0, f"SEQ step 0 cross-adapter: A={seq_a0!r} B={seq_b0!r}"
+    assert seq_a1 == seq_b1, f"SEQ step 1 cross-adapter: A={seq_a1!r} B={seq_b1!r}"
+
+    # Cross-scenario A independence: A's trajectory must NOT depend on
+    # whether B trained in between A.0 and A.1 (ALT) or didn't (SEQ).
+    assert alt_a0 == seq_a0, (
+        f"A_ALT step 0 ({alt_a0!r}) != A_SEQ step 0 ({seq_a0!r}); "
+        f"both pristine + identical data — pristine snapshot must be deterministic."
+    )
+    assert alt_a1 == seq_a1, (
+        f"A_ALT step 1 ({alt_a1!r}) != A_SEQ step 1 ({seq_a1!r}); "
+        f"|Δ|={abs(alt_a1 - seq_a1):.4f} nats. Symmetric prediction of TE FusedAdam's "
+        f"`param_groups[g]['step']` advancing globally across adapters instead of being "
+        f"held per-adapter — see ~/skyrl-seq-vs-alt-repro/README.md."
     )


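Why a leaked step counter moves the loss at all: Adam's bias correction is a function of the step count t, so an update computed at t=3 instead of t=2 applies a different effective step size even with identical gradients and moments. A self-contained numeric sketch (plain Python with made-up moment values, not tied to the test's model or optimizer):

```python
import math

def adam_update(grad, m, v, t, lr=1e-4, b1=0.9, b2=0.999, eps=1e-8):
    """One Adam step; the bias correction depends on the step counter t."""
    m = b1 * m + (1 - b1) * grad
    v = b2 * v + (1 - b2) * grad * grad
    m_hat = m / (1 - b1 ** t)  # correction shrinks as t grows
    v_hat = v / (1 - b2 ** t)
    return lr * m_hat / (math.sqrt(v_hat) + eps), m, v

# Same gradient, same moments (as after one prior step on grad=0.5);
# only the believed step count differs.
grad, m, v = 0.5, 0.05, 0.00025
upd_t2, _, _ = adam_update(grad, m, v, t=2)  # correct per-adapter counter
upd_t3, _, _ = adam_update(grad, m, v, t=3)  # counter leaked from another adapter
print(upd_t2, upd_t3, upd_t2 != upd_t3)
```

The two updates differ deterministically, which is why the tests can demand bit-exact equality: once every adapter sees its own t, the whole discrepancy vanishes rather than merely shrinking.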