Skip to content

Commit a02f48a

Browse files
committed
Simplify FFT request processor boundary
1 parent 96ddef9 commit a02f48a

7 files changed

Lines changed: 280 additions & 392 deletions

src/server/gateway.py

Lines changed: 72 additions & 86 deletions
Original file line number | Diff line number | Diff line change
@@ -15,20 +15,7 @@
1515
from opentelemetry.instrumentation.fastapi import FastAPIInstrumentor
1616
from opentelemetry.sdk.trace import TracerProvider
1717
from opentelemetry.sdk.trace.export import BatchSpanProcessor
18-
from pydantic import BaseModel
1918
from store import get_store
20-
from training_request_types import (
21-
CreateModelFromStatePayload,
22-
CreateModelPayload,
23-
ForwardBackwardPayload,
24-
LoadWeightsPayload,
25-
OptimStepPayload,
26-
SamplePayload,
27-
SaveStatePayload,
28-
SaveWeightsForSamplerPayload,
29-
TrainingCommand,
30-
TrainingOp,
31-
)
3219
from worker_launch_processor import (
3320
FFTWorkerManager,
3421
WorkerLaunchProcessor,
@@ -122,41 +109,40 @@ def is_sampler_weights_ref(model_id: str | None) -> bool:
122109
return len(parts) >= 3 and parts[1] == "sampler_weights"
123110

124111

125-
def make_training_command(
126-
op: TrainingOp,
112+
def make_training_request(
113+
op: str,
127114
model_id: str | None,
128-
payload: BaseModel,
115+
payload: dict,
129116
request_id: str | None = None,
130-
) -> TrainingCommand:
131-
return TrainingCommand(
132-
request_id=request_id or str(uuid.uuid4()),
133-
op=op,
134-
model_id=model_id,
135-
payload=payload.model_dump(exclude_none=True),
136-
)
117+
) -> dict:
118+
request = {
119+
"request_id": request_id or str(uuid.uuid4()),
120+
"op": op,
121+
"payload": payload,
122+
}
123+
if model_id is not None:
124+
request["model_id"] = model_id
125+
return request
137126

138127

139-
async def prepare_enqueue(command: TrainingCommand) -> tuple[str, dict]:
128+
async def enqueue(request: dict) -> str:
129+
"""Create a pending future, inject trace context, push to store. Returns req_id."""
130+
request_id = request["request_id"]
140131
carrier: dict = {}
141132
propagate.inject(carrier)
142-
command = command.model_copy(update={"trace_context": carrier})
143-
payload = command.model_dump(exclude_none=True)
144-
await store.set_future(command.request_id, {"status": "pending"})
145-
return command.request_id, payload
146-
147-
148-
async def enqueue(command: TrainingCommand) -> str:
149-
"""Create a pending future, inject trace context, push to store. Returns req_id."""
150-
req_id, payload = await prepare_enqueue(command)
151-
await store.put_request(payload)
152-
return req_id
133+
await store.set_future(request_id, {"status": "pending"})
134+
await store.put_request({**request, "trace_context": carrier})
135+
return request_id
153136

154137

155-
async def enqueue_worker_launch(command: TrainingCommand) -> str:
138+
async def enqueue_worker_launch(request: dict) -> str:
156139
"""Create a pending future and push a create-model request to the worker launch queue."""
157-
req_id, payload = await prepare_enqueue(command)
158-
await store.put_worker_launch_request(payload)
159-
return req_id
140+
request_id = request["request_id"]
141+
carrier: dict = {}
142+
propagate.inject(carrier)
143+
await store.set_future(request_id, {"status": "pending"})
144+
await store.put_worker_launch_request({**request, "trace_context": carrier})
145+
return request_id
160146

161147

162148
async def preflight_vllm() -> None:
@@ -236,10 +222,10 @@ async def lifespan(_: FastAPI):
236222
if not is_fft_enabled():
237223
import training_requests_processor
238224

239-
worker = training_requests_processor.create_training_worker()
225+
trainer = training_requests_processor.LoraTrainingWorker()
240226
if base_model:
241-
await asyncio.to_thread(worker.load_base_model, base_model)
242-
task = asyncio.create_task(training_requests_processor.run_training_requests_processor(worker))
227+
await asyncio.to_thread(trainer.load_base_model, base_model)
228+
task = asyncio.create_task(training_requests_processor.run_training_requests_processor(trainer))
243229
try:
244230
yield
245231
finally:
@@ -298,14 +284,14 @@ async def create_model(req: dict):
298284
if not base_model:
299285
return JSONResponse(status_code=400, content={"error": "base_model is required"})
300286
model_id = str(uuid.uuid4())
301-
command = make_training_command(
287+
command = make_training_request(
302288
"create_model",
303289
model_id,
304-
CreateModelPayload(
305-
base_model=base_model,
306-
lora_config=req.get("lora_config") or {},
307-
full_config=req.get("full_config") or {},
308-
),
290+
{
291+
"base_model": base_model,
292+
"lora_config": req.get("lora_config") or {},
293+
"full_config": req.get("full_config") or {},
294+
},
309295
request_id=model_id,
310296
)
311297
req_id = await enqueue_worker_launch(command) if is_fft_enabled() else await enqueue(command)
@@ -321,13 +307,13 @@ async def create_model_from_state(req: dict):
321307
# Resolve relative names under TMP_DIR/checkpoints, leave absolute paths alone.
322308
resolved_path = state_path if os.path.isabs(state_path) else os.path.join(TMP_DIR, "checkpoints", state_path)
323309
model_id = str(uuid.uuid4())
324-
command = make_training_command(
310+
command = make_training_request(
325311
"create_model_from_state",
326312
model_id,
327-
CreateModelFromStatePayload(
328-
state_path=resolved_path,
329-
restore_optimizer=bool(req.get("restore_optimizer", False)),
330-
),
313+
{
314+
"state_path": resolved_path,
315+
"restore_optimizer": bool(req.get("restore_optimizer", False)),
316+
},
331317
request_id=model_id,
332318
)
333319
req_id = await enqueue_worker_launch(command) if is_fft_enabled() else await enqueue(command)
@@ -376,14 +362,14 @@ async def forward_backward(req: dict):
376362
"""TrainingClient.forward_backward_async()"""
377363
fwd_input = req.get("forward_backward_input", {})
378364
req_id = await enqueue(
379-
make_training_command(
365+
make_training_request(
380366
"forward_backward",
381367
req.get("model_id"),
382-
ForwardBackwardPayload(
383-
data=fwd_input.get("data", []),
384-
loss_fn=fwd_input.get("loss_fn", "cross_entropy"),
385-
loss_config=fwd_input.get("loss_fn_config", {}),
386-
),
368+
{
369+
"data": fwd_input.get("data", []),
370+
"loss_fn": fwd_input.get("loss_fn", "cross_entropy"),
371+
"loss_config": fwd_input.get("loss_fn_config", {}),
372+
},
387373
)
388374
)
389375
return {"request_id": req_id}
@@ -393,10 +379,10 @@ async def forward_backward(req: dict):
393379
async def optim_step(req: dict):
394380
"""TrainingClient.optim_step_async()"""
395381
req_id = await enqueue(
396-
make_training_command(
382+
make_training_request(
397383
"optim_step",
398384
req.get("model_id"),
399-
OptimStepPayload(adam_params=req.get("adam_params", {})),
385+
{"adam_params": req.get("adam_params", {})},
400386
)
401387
)
402388
return {"request_id": req_id}
@@ -419,14 +405,14 @@ async def save_weights_for_sampler(req: dict):
419405

420406
session_id = sampler_session_id(model_id, seq_id)
421407
req_id = await enqueue(
422-
make_training_command(
408+
make_training_request(
423409
"save_weights_for_sampler",
424410
model_id,
425-
SaveWeightsForSamplerPayload(
426-
alias=alias,
427-
path=sampler_weights_path(model_id, alias) if alias else None,
428-
sampling_session_id=session_id,
429-
),
411+
{
412+
"alias": alias,
413+
"path": sampler_weights_path(model_id, alias) if alias else None,
414+
"sampling_session_id": session_id,
415+
},
430416
)
431417
)
432418
return {"request_id": req_id}
@@ -451,14 +437,14 @@ async def save_weights(req: dict):
451437

452438
req_id = str(uuid.uuid4())
453439
await enqueue(
454-
make_training_command(
440+
make_training_request(
455441
"save_state",
456442
model_id,
457-
SaveStatePayload(
458-
state_path=state_path,
459-
include_optimizer=bool(req.get("include_optimizer", False)),
460-
kind="weights",
461-
),
443+
{
444+
"state_path": state_path,
445+
"include_optimizer": bool(req.get("include_optimizer", False)),
446+
"kind": "weights",
447+
},
462448
request_id=req_id,
463449
)
464450
)
@@ -477,13 +463,13 @@ async def load_weights(req: dict):
477463

478464
resolved_path = checkpoint_state_path(model_id, state_path)
479465
req_id = await enqueue(
480-
make_training_command(
466+
make_training_request(
481467
"load_weights",
482468
model_id,
483-
LoadWeightsPayload(
484-
state_path=resolved_path,
485-
restore_optimizer=bool(req.get("optimizer", False)),
486-
),
469+
{
470+
"state_path": resolved_path,
471+
"restore_optimizer": bool(req.get("optimizer", False)),
472+
},
487473
)
488474
)
489475
return {"request_id": req_id}
@@ -528,16 +514,16 @@ async def asample(req: dict):
528514

529515
if get_sampler_backend() == "torch":
530516
req_id = await enqueue(
531-
make_training_command(
517+
make_training_request(
532518
"sample",
533519
base_model_id or model_id,
534-
SamplePayload(
535-
prompt_tokens=prompt,
536-
max_tokens=max_tokens,
537-
temperature=temperature,
538-
num_samples=num_samples,
539-
prompt_logprobs=bool(include_prompt_logprobs),
540-
),
520+
{
521+
"prompt_tokens": prompt,
522+
"max_tokens": max_tokens,
523+
"temperature": temperature,
524+
"num_samples": num_samples,
525+
"prompt_logprobs": bool(include_prompt_logprobs),
526+
},
541527
)
542528
)
543529
return {"request_id": req_id}

src/server/training/trainer_worker.py

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
from typing import Any
66

77
import torch
8-
from pydantic import BaseModel, field_validator, model_validator
8+
from pydantic import BaseModel
99
from transformers import PreTrainedModel, PreTrainedTokenizerBase
1010

1111
from training import losses
@@ -14,32 +14,11 @@
1414
class TensorData(BaseModel):
1515
data: list[int] | list[float]
1616

17-
@model_validator(mode="before")
18-
@classmethod
19-
def accept_raw_sequence(cls, value):
20-
if isinstance(value, list):
21-
return {"data": value}
22-
return value
23-
2417

2518
class Datum(BaseModel):
2619
loss_fn_inputs: dict[str, TensorData]
2720
model_input: list[int]
2821

29-
@field_validator("model_input", mode="before")
30-
@classmethod
31-
def flatten_model_input(cls, value):
32-
if not isinstance(value, dict):
33-
return value
34-
35-
tokens: list[int] = []
36-
for chunk in value.get("chunks", []):
37-
if isinstance(chunk, dict):
38-
tokens.extend(chunk.get("tokens", []))
39-
else:
40-
tokens.extend(getattr(chunk, "tokens", []))
41-
return tokens
42-
4322

4423
class BaseTrainerWorker:
4524
def __init__(self):

0 commit comments

Comments
 (0)