@@ -72,7 +72,7 @@ def _api_server(port: int, backend_config: dict | None = None):
7272 "--extra" ,
7373 "tinker" ,
7474 "--extra" ,
75- "skyrl_train " ,
75+ "megatron " ,
7676 "-m" ,
7777 "skyrl.tinker.api" ,
7878 "--host" ,
@@ -82,7 +82,7 @@ def _api_server(port: int, backend_config: dict | None = None):
8282 "--base-model" ,
8383 BASE_MODEL ,
8484 "--backend" ,
85- "skyrl_train " ,
85+ "megatron " ,
8686 "--backend-config" ,
8787 json .dumps (cfg ),
8888 "--database-url" ,
@@ -115,7 +115,7 @@ def _server_is_up(port: int) -> bool:
     import urllib.request

     try:
-        urllib.request.urlopen(f"http://0.0.0.0:{port}/api/v1/server_capabilities", timeout=2).read()
+        urllib.request.urlopen(f"http://0.0.0.0:{port}/api/v1/healthz", timeout=2).read()
         return True
     except (urllib.error.URLError, urllib.error.HTTPError, ConnectionError, TimeoutError):
         return False
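The readiness probe above now targets `/api/v1/healthz` instead of `/api/v1/server_capabilities`. As a rough sketch of how a fixture could wait on that probe, here is a minimal polling helper built only on the `_server_is_up` function from this file; the helper name and the timeout/interval values are illustrative, not part of the PR.

```python
import time

def _wait_for_server(port: int, timeout_s: float = 60.0, interval_s: float = 1.0) -> None:
    # Poll _server_is_up (which hits http://0.0.0.0:{port}/api/v1/healthz)
    # until it returns True or the deadline passes.
    deadline = time.monotonic() + timeout_s
    while time.monotonic() < deadline:
        if _server_is_up(port):
            return
        time.sleep(interval_s)
    raise TimeoutError(f"API server on port {port} was not healthy after {timeout_s}s")
```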
@@ -146,32 +146,33 @@ def service_client(server):


 def test_two_adapters_train_independently(service_client):
     """Two LoRA adapters share the same base model; training one must not
-    contaminate the other's weights."""
-    client_a = service_client.create_lora_training_client(base_model=BASE_MODEL, lora_rank=8)
-    client_b = service_client.create_lora_training_client(base_model=BASE_MODEL, lora_rank=8)
+    contaminate the other's weights.
+
+    SFT-scope test (multi_lora branch): we don't push weights to vLLM here
+    because save_weights_for_sampler is deliberately gated to single-adapter
+    in v1. We verify isolation by asserting A's loss continues to improve
+    after we've swapped to B and back — that's only possible if A's
+    optimizer state survived the swap-out + B-training + swap-in cycle.
+    """
+    client_a = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
+    client_b = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
     tok = client_a.get_tokenizer()

     data = [_make_datum(tok, "Question: 1+1?\nAnswer:", " 2")]

-    # Train A twice
+    # Train A twice (priming + one real step)
     for _ in range(2):
         client_a.forward_backward(data, "cross_entropy").result()
         client_a.optim_step(tinker_types.AdamParams(learning_rate=1e-3)).result()
-    a_path_after_training = client_a.save_weights_for_sampler(name="a_trained").result().path

-    # Train B once with a different LR
+    # Train B once with a different LR — this swaps the live adapter to B.
     client_b.forward_backward(data, "cross_entropy").result()
     client_b.optim_step(tinker_types.AdamParams(learning_rate=1e-4)).result()
-    b_path = client_b.save_weights_for_sampler(name="b_trained").result().path

-    # Switch back to A and check its state survived
-    a_path_after_swap = client_a.save_weights_for_sampler(name="a_after_swap").result().path
-
-    # The two A snapshots must be byte-identical: A's state should not have
-    # been changed by training B in between.
-    assert a_path_after_training and a_path_after_swap and b_path
-
-    # A continued training must converge from A's state, not from pristine.
+    # Switch back to A. If A's optimizer/grad state was wiped by the swap,
+    # the next step won't produce a sane gradient direction and loss won't
+    # improve. Single-step convergence on a fixed micro-batch is reliable
+    # for a tiny model + nontrivial LR.
     pre_loss = client_a.forward_backward(data, "cross_entropy").result()
     client_a.optim_step(tinker_types.AdamParams(learning_rate=1e-3)).result()
     post_loss = client_a.forward_backward(data, "cross_entropy").result()
@@ -184,24 +185,24 @@ def test_two_adapters_train_independently(service_client):


 def test_rank_mismatch_rejected(service_client):
-    service_client.create_lora_training_client(base_model=BASE_MODEL, lora_rank=8)
+    service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
     with pytest.raises(Exception) as exc:
-        service_client.create_lora_training_client(base_model=BASE_MODEL, lora_rank=16)
+        service_client.create_lora_training_client(base_model=BASE_MODEL, rank=16)
     assert "signature mismatch" in str(exc.value).lower() or "rank" in str(exc.value).lower()


 def test_sample_with_two_adapters_errors(service_client):
-    a = service_client.create_lora_training_client(base_model=BASE_MODEL, lora_rank=8)
-    service_client.create_lora_training_client(base_model=BASE_MODEL, lora_rank=8)
+    a = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
+    service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
     with pytest.raises(Exception):
         # save_weights_and_get_sampling_client routes through
         # save_sampler_checkpoint, which v1 refuses with >1 adapter.
         a.save_weights_and_get_sampling_client(name="should_fail")


 def test_delete_then_train_remaining(service_client):
-    a = service_client.create_lora_training_client(base_model=BASE_MODEL, lora_rank=8)
-    b = service_client.create_lora_training_client(base_model=BASE_MODEL, lora_rank=8)
+    a = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
+    b = service_client.create_lora_training_client(base_model=BASE_MODEL, rank=8)
     tok = a.get_tokenizer()
     data = [_make_datum(tok, "Q?", " a")]

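The rewritten isolation test above relies on a simple pattern: on a fixed micro-batch with a nontrivial learning rate, one optimizer step should reduce the loss, and that only holds if the adapter's optimizer state survived the swap. A condensed sketch of that pattern follows; `assert_step_improves` is a hypothetical helper, `tinker_types` is assumed to be imported as in the test module, and reading the scalar loss via `.loss` is an assumption, since the actual comparison sits outside the displayed hunks.

```python
def assert_step_improves(client, data, lr=1e-3):
    # Forward/backward to record the current loss, take one Adam step,
    # then forward/backward again and require an improvement.
    pre = client.forward_backward(data, "cross_entropy").result()
    client.optim_step(tinker_types.AdamParams(learning_rate=lr)).result()
    post = client.forward_backward(data, "cross_entropy").result()
    assert post.loss < pre.loss, "loss did not improve; adapter state may have been clobbered"
```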