Skip to content

Commit c18703c

Browse files
committed
small dispatch cleanup
1 parent 3cff3c3 commit c18703c

1 file changed

Lines changed: 49 additions & 66 deletions

File tree

src/server/training_requests_processor.py

Lines changed: 49 additions & 66 deletions
Original file line number | Diff line number | Diff line change
@@ -47,6 +47,53 @@ def parse_datum(raw: dict[str, Any]) -> Datum:
4747
class TrainingRequestsProcessor(Protocol):
4848
store: RequestStore
4949

50+
async def process_request(self, raw_request: dict[str, Any], model_id: str | None = None) -> None:
51+
request_id = raw_request.get("request_id")
52+
token = None
53+
54+
try:
55+
op = raw_request["op"]
56+
request_id = raw_request["request_id"]
57+
resolved_model_id = model_id or raw_request.get("model_id") or "default"
58+
59+
carrier = raw_request.get("trace_context")
60+
ctx = propagate.extract(carrier) if carrier else None
61+
token = otel_context.attach(ctx) if ctx else None
62+
63+
result = await self.dispatch_operation(op, raw_request.get("payload", {}), resolved_model_id)
64+
await self.store.set_future(request_id, result)
65+
except Exception as exc:
66+
traceback.print_exc()
67+
if request_id is None:
68+
raise
69+
await self.store.set_future(request_id, {"type": "RequestFailedResponse", "error_message": str(exc)})
70+
finally:
71+
if token:
72+
otel_context.detach(token)
73+
74+
async def dispatch_operation(self, op: str, payload: dict[str, Any], model_id: str) -> dict[str, Any]:
    """Route ``op`` to the handler method of the same name and return its result.

    Args:
        op: Operation name; must be one of the supported names listed below.
        payload: Operation arguments, passed through to the handler untouched.
        model_id: Model identifier, passed through to the handler untouched.

    Returns:
        Whatever the selected handler returns.

    Raises:
        NotImplementedError: If ``op`` is not a supported operation.
    """
    # Explicit allowlist: only these names may be resolved to methods, so an
    # op such as "run" can never reach an unrelated attribute. A tuple (not a
    # set) keeps the membership test safe for unhashable ``op`` values.
    supported_ops = (
        "create_model",
        "create_model_from_state",
        "forward_backward",
        "optim_step",
        "sample",
        "save_state",
        "load_weights",
        "save_weights_for_sampler",
        "save_weights",
    )
    if op not in supported_ops:
        raise NotImplementedError(f"Training request op {op!r} is not supported")

    # Every supported op is handled by the method that shares its name.
    handler = getattr(self, op)
    return await handler(payload, model_id)
96+
5097
# Protocol stub (no body): dispatch_operation awaits this for the "create_model" op.
async def create_model(self, payload: dict[str, Any], model_id: str) -> dict[str, Any]: ...
5198

5299
# Protocol stub (no body): dispatch_operation awaits this for the "create_model_from_state" op.
async def create_model_from_state(self, payload: dict[str, Any], model_id: str) -> dict[str, Any]: ...
@@ -66,72 +113,11 @@ async def save_weights_for_sampler(self, payload: dict[str, Any], model_id: str)
66113
# Protocol stub (no body): dispatch_operation awaits this for the "save_weights" op.
async def save_weights(self, payload: dict[str, Any], model_id: str) -> dict[str, Any]: ...
67114

68115

69-
async def dispatch_training_request(
70-
processor: TrainingRequestsProcessor,
71-
raw_request: dict[str, Any],
72-
model_id: str | None = None,
73-
) -> None:
74-
request_id = raw_request.get("request_id")
75-
token = None
76-
77-
try:
78-
op = raw_request["op"]
79-
request_id = raw_request["request_id"]
80-
resolved_model_id = model_id or raw_request.get("model_id") or "default"
81-
82-
carrier = raw_request.get("trace_context")
83-
ctx = propagate.extract(carrier) if carrier else None
84-
token = otel_context.attach(ctx) if ctx else None
85-
86-
result = await dispatch_training_operation(processor, op, raw_request.get("payload", {}), resolved_model_id)
87-
await processor.store.set_future(request_id, result)
88-
except Exception as exc:
89-
traceback.print_exc()
90-
if request_id is None:
91-
raise
92-
await processor.store.set_future(request_id, {"type": "RequestFailedResponse", "error_message": str(exc)})
93-
finally:
94-
if token:
95-
otel_context.detach(token)
96-
97-
98-
async def dispatch_training_operation(
99-
processor: TrainingRequestsProcessor,
100-
op: str,
101-
payload: dict[str, Any],
102-
model_id: str,
103-
) -> dict[str, Any]:
104-
match op:
105-
case "create_model":
106-
return await processor.create_model(payload, model_id)
107-
case "create_model_from_state":
108-
return await processor.create_model_from_state(payload, model_id)
109-
case "forward_backward":
110-
return await processor.forward_backward(payload, model_id)
111-
case "optim_step":
112-
return await processor.optim_step(payload, model_id)
113-
case "sample":
114-
return await processor.sample(payload, model_id)
115-
case "save_state":
116-
return await processor.save_state(payload, model_id)
117-
case "load_weights":
118-
return await processor.load_weights(payload, model_id)
119-
case "save_weights_for_sampler":
120-
return await processor.save_weights_for_sampler(payload, model_id)
121-
case "save_weights":
122-
return await processor.save_weights(payload, model_id)
123-
case _:
124-
raise NotImplementedError(f"Training request op {op!r} is not supported")
125-
126-
127-
class LoraTrainingRequestsProcessor:
116+
class LoraTrainingRequestsProcessor(TrainingRequestsProcessor):
128117
def __init__(self, store: RequestStore, worker: LoraTrainingWorker):
129118
self.store = store
130119
self.worker = worker
131120

132-
async def process_request(self, raw_request: dict[str, Any], model_id: str | None = None) -> None:
133-
await dispatch_training_request(self, raw_request, model_id)
134-
135121
async def run(self) -> None:
136122
print("[WORKER] LoRA training requests processor started.")
137123

@@ -249,7 +235,7 @@ async def save_weights(self, payload: dict[str, Any], model_id: str) -> dict[str
249235
return {"status": "ok", "type": "weights_saved"}
250236

251237

252-
class FFTTrainingRequestsProcessor:
238+
class FFTTrainingRequestsProcessor(TrainingRequestsProcessor):
253239
def __init__(
254240
self,
255241
store: RequestStore,
@@ -306,9 +292,6 @@ async def run_once(self) -> None:
306292
for request in batch:
307293
await self.process_request(request, self.model_id)
308294

309-
async def process_request(self, raw_request: dict[str, Any], model_id: str | None = None) -> None:
310-
await dispatch_training_request(self, raw_request, model_id or self.model_id)
311-
312295
async def create_model(self, payload: dict[str, Any], model_id: str) -> dict[str, Any]:
313296
raw_config = payload.get("full_config") or {}
314297
full_config = FFTConfig(**{k: v for k, v in raw_config.items() if k in FFTConfig.model_fields})

0 commit comments

Comments
 (0)