 
 
 @contextmanager
-def _api_server(port: int, backend_config: dict | None = None):
+def _api_server(port: int, backend_config: dict | None = None, base_model: str = BASE_MODEL):
     with tempfile.TemporaryDirectory() as tmp_dir:
         log_path = os.path.join(tmp_dir, "server.log")
         db_path = os.path.join(tmp_dir, "server.db")
@@ -84,7 +84,7 @@ def _api_server(port: int, backend_config: dict | None = None):
8484 "--port" ,
8585 str (port ),
8686 "--base-model" ,
87- BASE_MODEL ,
87+ base_model ,
8888 "--backend" ,
8989 "megatron" ,
9090 "--backend-config" ,
@@ -148,6 +148,26 @@ def service_client(server):
     return tinker.ServiceClient(base_url=f"http://0.0.0.0:{TEST_PORT}/", api_key=TINKER_API_KEY)
 
 
+# ---- Qwen3-0.6B fixtures (separate server, runs once per module) ----
+QWEN3_0_6B = "Qwen/Qwen3-0.6B"
+QWEN3_0_6B_PORT = 8012
+
+
+@pytest.fixture(scope="module")
+def qwen3_0_6b_server():
+    """Larger-model server for tests where the tiny model's loss deltas are
+    too small to surface real bugs. Qwen3-0.6B fits comfortably on a single
+    L4 (24 GB): ~1.2 GB bf16 weights + ~1 GB LoRA optimizer state + ~1.2 GB
+    vLLM."""
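+    # Rough arithmetic behind the weight figure (a sketch, not a measured
+    # number): 0.6e9 params * 2 bytes (bf16) ≈ 1.2 GB, so even with
+    # optimizer state and vLLM's own allocation the total stays far below
+    # the L4's 24 GB.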
+    with _api_server(QWEN3_0_6B_PORT, base_model=QWEN3_0_6B) as proc:
+        yield proc
+
+
+@pytest.fixture
+def qwen3_0_6b_service_client(qwen3_0_6b_server):
+    return tinker.ServiceClient(base_url=f"http://0.0.0.0:{QWEN3_0_6B_PORT}/", api_key=TINKER_API_KEY)
+
+
 def test_two_adapters_train_independently(service_client):
     """Two LoRA adapters share the same base model; training one must not
     contaminate the other's weights.
@@ -242,25 +262,114 @@ def fb_step(c):
242262 f"[seq_vs_alt] step 1: A={ a1 !r} B={ b1 !r} |Δ|={ abs (a1 - b1 ):.6e} "
243263 )
244264
-    # Step 0: both adapters were pristine + saw identical data.
-    assert abs(a0 - b0) < 1e-3, f"step 0 loss differs: A={a0!r} B={b0!r} (Δ={abs(a0 - b0):.6f})"
+    # Step 0: both adapters were pristine + saw identical data → bit-exact.
+    assert a0 == b0, f"step 0 loss differs: A={a0!r} B={b0!r} (Δ={abs(a0 - b0):.6e})"
 
     # Step 1: both adapters had exactly one optim_step from pristine on
-    # identical data. If the per-adapter step counter is correctly
-    # snapshotted/restored by AdapterStore, both updates use Adam at t=2
-    # (after the one-step priming), so their post-update states match and
-    # their step-1 losses match.
+    # identical data. With AdapterStore correctly snapshotting both per-
+    # param state and per-param-group state (TE FusedAdam tracks the
+    # bias-correction step counter at the group level — see
+    # NovaSky-AI/SkyRL multi_lora @ aca96d0c), both updates use t=2 and
+    # the post-update parameters are bit-identical. Bit-exact loss
+    # follows.
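+    # For scale (sketch arithmetic, assuming Adam defaults β1=0.9,
+    # β2=0.999): the bias-corrected step scale sqrt(1 - β2^t) / (1 - β1^t)
+    # is ≈0.2353 at t=2 but ≈0.2020 at t=3, a ~14% smaller update, easily
+    # enough to knock the post-update loss off bit-exactness.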
     #
-    # If a global step counter advanced (one for A's step 0, one for B's
-    # step 0), B's first real update saw t=3 vs A's t=2, producing a
-    # measurably different update.
-    delta = abs(a1 - b1)
-    assert delta < 1e-2, (
-        f"step 1 loss diverges between adapters: A={a1!r} B={b1!r} (|Δ|={delta:.4f}). "
-        f"Symmetric prediction of a shared global step counter "
-        f"(LR scheduler position or Adam bias-correction step) advancing on every "
-        f"optim_step instead of being held per-adapter — see "
-        f"~/skyrl-seq-vs-alt-repro/README.md."
+    # Pre-fix on the tiny test model, this delta was 1.7e-4 (small but
+    # non-zero — the bug WAS present, just below FP noise on a tiny
+    # output distribution). On Qwen3-4B + PPO it was 0.117 nats. The
+    # bit-exact assertion catches both regressions.
+    assert a1 == b1, (
+        f"step 1 loss diverges between adapters: A={a1!r} B={b1!r} (|Δ|={abs(a1 - b1):.6e}). "
+        f"Consistent with a shared global step counter (TE FusedAdam's "
+        f"`param_groups[g]['step']`) advancing on every optim_step instead of being "
+        f"held per-adapter — see ~/skyrl-seq-vs-alt-repro/README.md."
+    )
+
+
+def test_seq_vs_alt_qwen3_0_6b_cross_scenario(qwen3_0_6b_service_client):
+    """Larger-model SEQ-vs-ALT regression test, mirroring the upstream
+    repro at ~/skyrl-seq-vs-alt-repro (which uses Qwen3-4B + PPO).
+
+    The bug: TE FusedAdam tracks the bias-correction `step` at the
+    `optimizer.param_groups[g]['step']` level, not in the per-param state
+    dict. Before NovaSky-AI/SkyRL multi_lora @ aca96d0c, AdapterStore only
+    snapshotted per-param state, so this counter advanced globally across
+    adapters — giving each new adapter a bias correction that was wrong in
+    proportion to how many optim_steps had fired since policy build.
+
+    Why we need this test alongside the tiny-model one: on
+    trl-internal-testing/tiny-Qwen3ForCausalLM the pre-fix divergence was
+    only 1.7e-4 (well within FP noise). On Qwen3-4B + PPO it was 0.45
+    nats. Qwen3-0.6B + cross_entropy is the smallest configuration that
+    reliably surfaces the bug at a non-trivial magnitude (~0.05-0.2 nats
+    pre-fix) while still fitting on a single L4 GPU.
+
+    The two killer signals from the upstream repro:
+    1. Within scenario: A's loss == B's loss at every step (both
+       adapters were pristine + saw identical data and ran identical
+       per-adapter optim_step sequences — modulo where in the global
+       schedule those steps happened).
+    2. Cross scenario: A_ALT step N == A_SEQ step N (A's per-adapter
+       trajectory must NOT depend on whether B trained in between).
+
+    With the AdapterStore param-group fix in place, all four runs
+    (A_ALT, B_ALT, A_SEQ, B_SEQ) land on the same loss at every step,
+    bit-for-bit.
+    """
+    sc = qwen3_0_6b_service_client
+
+    # ---- ALT scenario ----
+    a_alt = sc.create_lora_training_client(base_model=QWEN3_0_6B, rank=8)
+    b_alt = sc.create_lora_training_client(base_model=QWEN3_0_6B, rank=8)
+    tok = a_alt.get_tokenizer()
+    data = [_make_datum(tok, "Question: 1+1?\nAnswer:", " 2")]
+
+    def fb_step(c, lr=1e-4):
+        out = c.forward_backward(data, "cross_entropy").result()
+        loss = sum(sum(o["elementwise_loss"].data) for o in out.loss_fn_outputs)
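+        # elementwise_loss appears to hold per-token losses; the double sum
+        # above collapses them into one unnormalized scalar per step, which
+        # is what the bit-exact asserts below compare.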
+        c.optim_step(tinker_types.AdamParams(learning_rate=lr)).result()
+        return loss
+
+    # ALT order: A.0, B.0, A.1, B.1
+    alt_a0 = fb_step(a_alt)
+    alt_b0 = fb_step(b_alt)
+    alt_a1 = fb_step(a_alt)
+    alt_b1 = fb_step(b_alt)
+
+    # ---- SEQ scenario: fresh adapters, sequential order ----
+    a_seq = sc.create_lora_training_client(base_model=QWEN3_0_6B, rank=8)
+    b_seq = sc.create_lora_training_client(base_model=QWEN3_0_6B, rank=8)
+
+    # SEQ order: A.0, A.1, B.0, B.1
+    seq_a0 = fb_step(a_seq)
+    seq_a1 = fb_step(a_seq)
+    seq_b0 = fb_step(b_seq)
+    seq_b1 = fb_step(b_seq)
+
+    print(
+        f"\n[seq_vs_alt_qwen3_0_6b] ALT 0: A={alt_a0!r} B={alt_b0!r}\n"
+        f"[seq_vs_alt_qwen3_0_6b] ALT 1: A={alt_a1!r} B={alt_b1!r}\n"
+        f"[seq_vs_alt_qwen3_0_6b] SEQ 0: A={seq_a0!r} B={seq_b0!r}\n"
+        f"[seq_vs_alt_qwen3_0_6b] SEQ 1: A={seq_a1!r} B={seq_b1!r}\n"
+    )
+
+    # Within-scenario isolation: A and B saw the same data and were in
+    # identical per-adapter states when their respective fb_step(N) ran.
+    assert alt_a0 == alt_b0, f"ALT step 0 cross-adapter: A={alt_a0!r} B={alt_b0!r}"
+    assert alt_a1 == alt_b1, f"ALT step 1 cross-adapter: A={alt_a1!r} B={alt_b1!r}"
+    assert seq_a0 == seq_b0, f"SEQ step 0 cross-adapter: A={seq_a0!r} B={seq_b0!r}"
+    assert seq_a1 == seq_b1, f"SEQ step 1 cross-adapter: A={seq_a1!r} B={seq_b1!r}"
+
+    # Cross-scenario A independence: A's trajectory must NOT depend on
+    # whether B trained in between A.0 and A.1 (ALT) or didn't (SEQ).
+    assert alt_a0 == seq_a0, (
+        f"A_ALT step 0 ({alt_a0!r}) != A_SEQ step 0 ({seq_a0!r}); "
+        f"both pristine + identical data — the pristine snapshot must be deterministic."
+    )
+    assert alt_a1 == seq_a1, (
+        f"A_ALT step 1 ({alt_a1!r}) != A_SEQ step 1 ({seq_a1!r}); "
+        f"|Δ|={abs(alt_a1 - seq_a1):.4f} nats. Consistent with TE FusedAdam's "
+        f"`param_groups[g]['step']` advancing globally across adapters instead of being "
+        f"held per-adapter — see ~/skyrl-seq-vs-alt-repro/README.md."
     )
 
 