@@ -47,6 +47,53 @@ def parse_datum(raw: dict[str, Any]) -> Datum:
4747class TrainingRequestsProcessor (Protocol ):
4848 store : RequestStore
4949
50+ async def process_request (self , raw_request : dict [str , Any ], model_id : str | None = None ) -> None :
51+ request_id = raw_request .get ("request_id" )
52+ token = None
53+
54+ try :
55+ op = raw_request ["op" ]
56+ request_id = raw_request ["request_id" ]
57+ resolved_model_id = model_id or raw_request .get ("model_id" ) or "default"
58+
59+ carrier = raw_request .get ("trace_context" )
60+ ctx = propagate .extract (carrier ) if carrier else None
61+ token = otel_context .attach (ctx ) if ctx else None
62+
63+ result = await self .dispatch_operation (op , raw_request .get ("payload" , {}), resolved_model_id )
64+ await self .store .set_future (request_id , result )
65+ except Exception as exc :
66+ traceback .print_exc ()
67+ if request_id is None :
68+ raise
69+ await self .store .set_future (request_id , {"type" : "RequestFailedResponse" , "error_message" : str (exc )})
70+ finally :
71+ if token :
72+ otel_context .detach (token )
73+
74+ async def dispatch_operation (self , op : str , payload : dict [str , Any ], model_id : str ) -> dict [str , Any ]:
75+ match op :
76+ case "create_model" :
77+ return await self .create_model (payload , model_id )
78+ case "create_model_from_state" :
79+ return await self .create_model_from_state (payload , model_id )
80+ case "forward_backward" :
81+ return await self .forward_backward (payload , model_id )
82+ case "optim_step" :
83+ return await self .optim_step (payload , model_id )
84+ case "sample" :
85+ return await self .sample (payload , model_id )
86+ case "save_state" :
87+ return await self .save_state (payload , model_id )
88+ case "load_weights" :
89+ return await self .load_weights (payload , model_id )
90+ case "save_weights_for_sampler" :
91+ return await self .save_weights_for_sampler (payload , model_id )
92+ case "save_weights" :
93+ return await self .save_weights (payload , model_id )
94+ case _:
95+ raise NotImplementedError (f"Training request op { op !r} is not supported" )
96+
5097 async def create_model (self , payload : dict [str , Any ], model_id : str ) -> dict [str , Any ]: ...
5198
5299 async def create_model_from_state (self , payload : dict [str , Any ], model_id : str ) -> dict [str , Any ]: ...
@@ -66,72 +113,11 @@ async def save_weights_for_sampler(self, payload: dict[str, Any], model_id: str)
66113 async def save_weights (self , payload : dict [str , Any ], model_id : str ) -> dict [str , Any ]: ...
67114
68115
69- async def dispatch_training_request (
70- processor : TrainingRequestsProcessor ,
71- raw_request : dict [str , Any ],
72- model_id : str | None = None ,
73- ) -> None :
74- request_id = raw_request .get ("request_id" )
75- token = None
76-
77- try :
78- op = raw_request ["op" ]
79- request_id = raw_request ["request_id" ]
80- resolved_model_id = model_id or raw_request .get ("model_id" ) or "default"
81-
82- carrier = raw_request .get ("trace_context" )
83- ctx = propagate .extract (carrier ) if carrier else None
84- token = otel_context .attach (ctx ) if ctx else None
85-
86- result = await dispatch_training_operation (processor , op , raw_request .get ("payload" , {}), resolved_model_id )
87- await processor .store .set_future (request_id , result )
88- except Exception as exc :
89- traceback .print_exc ()
90- if request_id is None :
91- raise
92- await processor .store .set_future (request_id , {"type" : "RequestFailedResponse" , "error_message" : str (exc )})
93- finally :
94- if token :
95- otel_context .detach (token )
96-
97-
98- async def dispatch_training_operation (
99- processor : TrainingRequestsProcessor ,
100- op : str ,
101- payload : dict [str , Any ],
102- model_id : str ,
103- ) -> dict [str , Any ]:
104- match op :
105- case "create_model" :
106- return await processor .create_model (payload , model_id )
107- case "create_model_from_state" :
108- return await processor .create_model_from_state (payload , model_id )
109- case "forward_backward" :
110- return await processor .forward_backward (payload , model_id )
111- case "optim_step" :
112- return await processor .optim_step (payload , model_id )
113- case "sample" :
114- return await processor .sample (payload , model_id )
115- case "save_state" :
116- return await processor .save_state (payload , model_id )
117- case "load_weights" :
118- return await processor .load_weights (payload , model_id )
119- case "save_weights_for_sampler" :
120- return await processor .save_weights_for_sampler (payload , model_id )
121- case "save_weights" :
122- return await processor .save_weights (payload , model_id )
123- case _:
124- raise NotImplementedError (f"Training request op { op !r} is not supported" )
125-
126-
127- class LoraTrainingRequestsProcessor :
116+ class LoraTrainingRequestsProcessor (TrainingRequestsProcessor ):
128117 def __init__ (self , store : RequestStore , worker : LoraTrainingWorker ):
129118 self .store = store
130119 self .worker = worker
131120
132- async def process_request (self , raw_request : dict [str , Any ], model_id : str | None = None ) -> None :
133- await dispatch_training_request (self , raw_request , model_id )
134-
135121 async def run (self ) -> None :
136122 print ("[WORKER] LoRA training requests processor started." )
137123
@@ -249,7 +235,7 @@ async def save_weights(self, payload: dict[str, Any], model_id: str) -> dict[str
249235 return {"status" : "ok" , "type" : "weights_saved" }
250236
251237
252- class FFTTrainingRequestsProcessor :
238+ class FFTTrainingRequestsProcessor ( TrainingRequestsProcessor ) :
253239 def __init__ (
254240 self ,
255241 store : RequestStore ,
@@ -306,9 +292,6 @@ async def run_once(self) -> None:
306292 for request in batch :
307293 await self .process_request (request , self .model_id )
308294
309- async def process_request (self , raw_request : dict [str , Any ], model_id : str | None = None ) -> None :
310- await dispatch_training_request (self , raw_request , model_id or self .model_id )
311-
312295 async def create_model (self , payload : dict [str , Any ], model_id : str ) -> dict [str , Any ]:
313296 raw_config = payload .get ("full_config" ) or {}
314297 full_config = FFTConfig (** {k : v for k , v in raw_config .items () if k in FFTConfig .model_fields })
0 commit comments