1515from opentelemetry .instrumentation .fastapi import FastAPIInstrumentor
1616from opentelemetry .sdk .trace import TracerProvider
1717from opentelemetry .sdk .trace .export import BatchSpanProcessor
18- from pydantic import BaseModel
1918from store import get_store
20- from training_request_types import (
21- CreateModelFromStatePayload ,
22- CreateModelPayload ,
23- ForwardBackwardPayload ,
24- LoadWeightsPayload ,
25- OptimStepPayload ,
26- SamplePayload ,
27- SaveStatePayload ,
28- SaveWeightsForSamplerPayload ,
29- TrainingCommand ,
30- TrainingOp ,
31- )
3219from worker_launch_processor import (
3320 FFTWorkerManager ,
3421 WorkerLaunchProcessor ,
@@ -122,41 +109,40 @@ def is_sampler_weights_ref(model_id: str | None) -> bool:
122109 return len (parts ) >= 3 and parts [1 ] == "sampler_weights"
123110
124111
125- def make_training_command (
126- op : TrainingOp ,
112+ def make_training_request (
113+ op : str ,
127114 model_id : str | None ,
128- payload : BaseModel ,
115+ payload : dict ,
129116 request_id : str | None = None ,
130- ) -> TrainingCommand :
131- return TrainingCommand (
132- request_id = request_id or str (uuid .uuid4 ()),
133- op = op ,
134- model_id = model_id ,
135- payload = payload .model_dump (exclude_none = True ),
136- )
117+ ) -> dict :
118+ request = {
119+ "request_id" : request_id or str (uuid .uuid4 ()),
120+ "op" : op ,
121+ "payload" : payload ,
122+ }
123+ if model_id is not None :
124+ request ["model_id" ] = model_id
125+ return request
137126
138127
async def enqueue(request: dict) -> str:
    """Create a pending future, inject trace context, push to store. Returns req_id.

    Args:
        request: Request envelope; must contain a ``"request_id"`` key.

    Returns:
        The ``"request_id"`` taken from ``request``.

    Raises:
        KeyError: If ``request`` has no ``"request_id"`` entry.
    """
    request_id = request["request_id"]
    # `propagate.inject` fills the carrier with the current trace context
    # (presumably opentelemetry.propagate -- confirm against the file's imports).
    carrier: dict = {}
    propagate.inject(carrier)
    # Record the pending future before the request is pushed to the store.
    await store.set_future(request_id, {"status": "pending"})
    # Push a copy with the trace context added; the caller's dict is not mutated.
    await store.put_request({**request, "trace_context": carrier})
    return request_id
153136
154137
async def enqueue_worker_launch(request: dict) -> str:
    """Create a pending future and push a create-model request to the worker launch queue.

    Args:
        request: Request envelope; must contain a ``"request_id"`` key.

    Returns:
        The ``"request_id"`` taken from ``request``.

    Raises:
        KeyError: If ``request`` has no ``"request_id"`` entry.
    """
    request_id = request["request_id"]
    # `propagate.inject` fills the carrier with the current trace context
    # (presumably opentelemetry.propagate -- confirm against the file's imports).
    carrier: dict = {}
    propagate.inject(carrier)
    # Record the pending future before the request is pushed to the store.
    await store.set_future(request_id, {"status": "pending"})
    # Push a copy with the trace context added; the caller's dict is not mutated.
    await store.put_worker_launch_request({**request, "trace_context": carrier})
    return request_id
160146
161147
162148async def preflight_vllm () -> None :
@@ -236,10 +222,10 @@ async def lifespan(_: FastAPI):
236222 if not is_fft_enabled ():
237223 import training_requests_processor
238224
239- worker = training_requests_processor .create_training_worker ()
225+ trainer = training_requests_processor .LoraTrainingWorker ()
240226 if base_model :
241- await asyncio .to_thread (worker .load_base_model , base_model )
242- task = asyncio .create_task (training_requests_processor .run_training_requests_processor (worker ))
227+ await asyncio .to_thread (trainer .load_base_model , base_model )
228+ task = asyncio .create_task (training_requests_processor .run_training_requests_processor (trainer ))
243229 try :
244230 yield
245231 finally :
@@ -298,14 +284,14 @@ async def create_model(req: dict):
298284 if not base_model :
299285 return JSONResponse (status_code = 400 , content = {"error" : "base_model is required" })
300286 model_id = str (uuid .uuid4 ())
301- command = make_training_command (
287+ command = make_training_request (
302288 "create_model" ,
303289 model_id ,
304- CreateModelPayload (
305- base_model = base_model ,
306- lora_config = req .get ("lora_config" ) or {},
307- full_config = req .get ("full_config" ) or {},
308- ) ,
290+ {
291+ " base_model" : base_model ,
292+ " lora_config" : req .get ("lora_config" ) or {},
293+ " full_config" : req .get ("full_config" ) or {},
294+ } ,
309295 request_id = model_id ,
310296 )
311297 req_id = await enqueue_worker_launch (command ) if is_fft_enabled () else await enqueue (command )
@@ -321,13 +307,13 @@ async def create_model_from_state(req: dict):
321307 # Resolve relative names under TMP_DIR/checkpoints, leave absolute paths alone.
322308 resolved_path = state_path if os .path .isabs (state_path ) else os .path .join (TMP_DIR , "checkpoints" , state_path )
323309 model_id = str (uuid .uuid4 ())
324- command = make_training_command (
310+ command = make_training_request (
325311 "create_model_from_state" ,
326312 model_id ,
327- CreateModelFromStatePayload (
328- state_path = resolved_path ,
329- restore_optimizer = bool (req .get ("restore_optimizer" , False )),
330- ) ,
313+ {
314+ " state_path" : resolved_path ,
315+ " restore_optimizer" : bool (req .get ("restore_optimizer" , False )),
316+ } ,
331317 request_id = model_id ,
332318 )
333319 req_id = await enqueue_worker_launch (command ) if is_fft_enabled () else await enqueue (command )
@@ -376,14 +362,14 @@ async def forward_backward(req: dict):
376362 """TrainingClient.forward_backward_async()"""
377363 fwd_input = req .get ("forward_backward_input" , {})
378364 req_id = await enqueue (
379- make_training_command (
365+ make_training_request (
380366 "forward_backward" ,
381367 req .get ("model_id" ),
382- ForwardBackwardPayload (
383- data = fwd_input .get ("data" , []),
384- loss_fn = fwd_input .get ("loss_fn" , "cross_entropy" ),
385- loss_config = fwd_input .get ("loss_fn_config" , {}),
386- ) ,
368+ {
369+ " data" : fwd_input .get ("data" , []),
370+ " loss_fn" : fwd_input .get ("loss_fn" , "cross_entropy" ),
371+ " loss_config" : fwd_input .get ("loss_fn_config" , {}),
372+ } ,
387373 )
388374 )
389375 return {"request_id" : req_id }
@@ -393,10 +379,10 @@ async def forward_backward(req: dict):
async def optim_step(req: dict):
    """TrainingClient.optim_step_async()

    Enqueue an ``"optim_step"`` training request built from ``req``.

    Args:
        req: Request body. ``"model_id"`` is read with ``.get`` (``None``
            when absent); ``"adam_params"`` defaults to an empty dict.

    Returns:
        ``{"request_id": <id>}`` with the id returned by ``enqueue``.
    """
    req_id = await enqueue(
        make_training_request(
            "optim_step",
            req.get("model_id"),
            {"adam_params": req.get("adam_params", {})},
        )
    )
    return {"request_id": req_id}
@@ -419,14 +405,14 @@ async def save_weights_for_sampler(req: dict):
419405
420406 session_id = sampler_session_id (model_id , seq_id )
421407 req_id = await enqueue (
422- make_training_command (
408+ make_training_request (
423409 "save_weights_for_sampler" ,
424410 model_id ,
425- SaveWeightsForSamplerPayload (
426- alias = alias ,
427- path = sampler_weights_path (model_id , alias ) if alias else None ,
428- sampling_session_id = session_id ,
429- ) ,
411+ {
412+ " alias" : alias ,
413+ " path" : sampler_weights_path (model_id , alias ) if alias else None ,
414+ " sampling_session_id" : session_id ,
415+ } ,
430416 )
431417 )
432418 return {"request_id" : req_id }
@@ -451,14 +437,14 @@ async def save_weights(req: dict):
451437
452438 req_id = str (uuid .uuid4 ())
453439 await enqueue (
454- make_training_command (
440+ make_training_request (
455441 "save_state" ,
456442 model_id ,
457- SaveStatePayload (
458- state_path = state_path ,
459- include_optimizer = bool (req .get ("include_optimizer" , False )),
460- kind = "weights" ,
461- ) ,
443+ {
444+ " state_path" : state_path ,
445+ " include_optimizer" : bool (req .get ("include_optimizer" , False )),
446+ " kind" : "weights" ,
447+ } ,
462448 request_id = req_id ,
463449 )
464450 )
@@ -477,13 +463,13 @@ async def load_weights(req: dict):
477463
478464 resolved_path = checkpoint_state_path (model_id , state_path )
479465 req_id = await enqueue (
480- make_training_command (
466+ make_training_request (
481467 "load_weights" ,
482468 model_id ,
483- LoadWeightsPayload (
484- state_path = resolved_path ,
485- restore_optimizer = bool (req .get ("optimizer" , False )),
486- ) ,
469+ {
470+ " state_path" : resolved_path ,
471+ " restore_optimizer" : bool (req .get ("optimizer" , False )),
472+ } ,
487473 )
488474 )
489475 return {"request_id" : req_id }
@@ -528,16 +514,16 @@ async def asample(req: dict):
528514
529515 if get_sampler_backend () == "torch" :
530516 req_id = await enqueue (
531- make_training_command (
517+ make_training_request (
532518 "sample" ,
533519 base_model_id or model_id ,
534- SamplePayload (
535- prompt_tokens = prompt ,
536- max_tokens = max_tokens ,
537- temperature = temperature ,
538- num_samples = num_samples ,
539- prompt_logprobs = bool (include_prompt_logprobs ),
540- ) ,
520+ {
521+ " prompt_tokens" : prompt ,
522+ " max_tokens" : max_tokens ,
523+ " temperature" : temperature ,
524+ " num_samples" : num_samples ,
525+ " prompt_logprobs" : bool (include_prompt_logprobs ),
526+ } ,
541527 )
542528 )
543529 return {"request_id" : req_id }
0 commit comments