feat: FFT snapshot integration #116
async def clock_cycle_loop(worker: TrainingWorker, model_id: str | None = None) -> None:
async def clock_cycle_loop(
It might make sense to anchor this method on a class that is explicitly initialized. We have accumulated a bunch of config, such as is-fft-enabled, redis-url, and snapshot-agent-lock, that could be initialized at instance creation. That would also make this more testable, since we could inject the worker, snapshot agent, etc.
The class could be called Trainer or FFTrainer and would process FFT requests for a given model-id.
WDYT?
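A minimal sketch of the shape suggested above, assuming a worker can be modeled as an async callable and the snapshot agent lock as an `asyncio.Lock`. `TrainerConfig`, `FFTrainer.process`, and the fake worker are hypothetical names used only for illustration; they are not taken from this PR.

```python
import asyncio
from dataclasses import dataclass
from typing import Any, Awaitable, Callable, Optional

# Hypothetical alias: a worker is modeled as an async callable taking one operation.
Worker = Callable[[dict], Awaitable[Any]]


@dataclass(frozen=True)
class TrainerConfig:
    """Hypothetical bundle of settings that would be fixed at instance creation."""
    fft_enabled: bool
    redis_url: str


class FFTrainer:
    """Processes training requests for one model_id (illustrative only)."""

    def __init__(self, config: TrainerConfig, worker: Worker,
                 snapshot_lock: asyncio.Lock, model_id: Optional[str] = None) -> None:
        self.config = config
        self.worker = worker
        self.snapshot_lock = snapshot_lock
        self.model_id = model_id

    async def process(self, operation: dict) -> Any:
        if self.config.fft_enabled:
            # Assumption: FFT operations run while holding the snapshot agent lock.
            async with self.snapshot_lock:
                return await self.worker(operation)
        return await self.worker(operation)


async def demo() -> tuple[Any, list[dict]]:
    seen: list[dict] = []

    async def fake_worker(operation: dict) -> str:
        # Test double injected in place of a real worker.
        seen.append(operation)
        return f"ran {operation['type']}"

    trainer = FFTrainer(
        TrainerConfig(fft_enabled=True, redis_url="redis://localhost:6379"),
        fake_worker,
        asyncio.Lock(),
        model_id="model-1",
    )
    result = await trainer.process({"type": "forward_backward"})
    return result, seen
```

Because the worker and lock are constructor arguments, a test can pass in doubles like `fake_worker` without touching Redis or a GPU.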
Force-pushed from a02f48a to 3cff3c3.
def main() -> None:
from clock_cycle import main as clock_cycle_main
def start_request_processing_loop() -> None:
import training_requests_processor
Is it possible to get rid of the conditional import?
loss_fn_inputs: dict[str, TensorData]
model_input: list[int]

@field_validator("model_input", mode="before")
Please add a comment explaining the "why".
I initially added internal types in this PR but later removed them because I felt they bloated the change. I will add them back in a later PR.
},
request_id=model_id,
)
req_id = await enqueue_worker_launch(command) if is_fft_enabled() else await enqueue(command)
I wonder if we could simply enqueue the "create_model" training request and let the backend encapsulate the logic of whether to launch a worker, keeping the API gateway decoupled from the backend.
I agree this is the shape to target. The weirdness comes from the fact that we have to dynamically spin up a new worker for each FFT run, which is why we have a separate queue for doing so. I will create an issue for this and think more about it.
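One way the target shape discussed in this thread could look, as a rough sketch only: the gateway always enqueues onto a single queue, and the backend decides whether a worker launch is needed. `enqueue_training_request`, `dispatch_next`, and the returned action strings are hypothetical, and an in-memory `asyncio.Queue` stands in for the real queue.

```python
import asyncio


async def enqueue_training_request(command: dict, queue: asyncio.Queue) -> None:
    """Gateway side (hypothetical): always put the request on one queue."""
    await queue.put(command)


async def dispatch_next(queue: asyncio.Queue, fft_enabled: bool) -> str:
    """Backend side (hypothetical): decide here whether a worker must be launched."""
    command = await queue.get()
    if fft_enabled and command["type"] == "create_model":
        # Assumption: each FFT run needs a freshly launched worker.
        return "launch_worker"
    return "run_on_existing_worker"


async def demo() -> list[str]:
    queue: asyncio.Queue = asyncio.Queue()
    await enqueue_training_request({"type": "create_model"}, queue)
    await enqueue_training_request({"type": "create_model"}, queue)
    return [
        await dispatch_next(queue, fft_enabled=True),
        await dispatch_next(queue, fft_enabled=False),
    ]
```

With this split the gateway no longer needs to call `is_fft_enabled()`; the FFT-specific branch lives entirely on the backend side.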
print("[WARNING] BASE_MODEL not provided. Cold-start penalty will apply on first request.")
is_ready = True

if not is_fft_enabled():
Do we want to expose a healthcheck for FFT as well?
Because we are dynamically launching workers for each job, I think the meaning of a health check is different from that of a LoraWorker. I will think more about this.
The change looks good to me. I have minor nits but nothing blocking. Feel free to merge and address the nits in a follow-up.
This PR does two primary things. First, it makes the contract for the snapshot agent depend only on the process ID (see #109 for the initial implementation). Second, it refactors what used to be called clock_cycle.py into training_request_processor.py, which better matches what the code actually does. The gateway is responsible for putting tinker-shaped training operations on queues, and the request processor drains those operations and executes them against a concrete worker (see #113 for this split). With this shape we can have two different request processors: LoraTrainingRequestsProcessor and FFTTrainingRequestsProcessor. Both agree on the contract of which operations can come off the queue but can differ in how they compose operations with their workers. Namely, for FFT we need to use the snapshot agent to acquire a GPU lock before executing operations.
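The two-processor shape described above could be sketched roughly as follows. This is an illustration under assumptions, not the code in this PR: the base class name, the `drain` and `execute` methods, the async-callable worker, and the use of an `asyncio.Lock` in place of the snapshot agent's GPU lock are all hypothetical, and the operation names are placeholders.

```python
import asyncio
from abc import ABC, abstractmethod
from typing import Any, Awaitable, Callable

# Hypothetical alias: a worker is modeled as an async callable taking one operation.
Worker = Callable[[str], Awaitable[Any]]


class TrainingRequestsProcessor(ABC):
    """Shared contract (hypothetical): drain operations and run them on a worker."""

    def __init__(self, queue: asyncio.Queue, worker: Worker) -> None:
        self.queue = queue
        self.worker = worker

    async def drain(self) -> list[Any]:
        # Process everything currently queued, in FIFO order.
        results = []
        while not self.queue.empty():
            operation = self.queue.get_nowait()
            results.append(await self.execute(operation))
        return results

    @abstractmethod
    async def execute(self, operation: str) -> Any:
        """Compose one operation with the worker."""


class LoraTrainingRequestsProcessor(TrainingRequestsProcessor):
    async def execute(self, operation: str) -> Any:
        return await self.worker(operation)


class FFTTrainingRequestsProcessor(TrainingRequestsProcessor):
    def __init__(self, queue: asyncio.Queue, worker: Worker, gpu_lock: asyncio.Lock) -> None:
        super().__init__(queue, worker)
        self.gpu_lock = gpu_lock

    async def execute(self, operation: str) -> Any:
        # Assumption: the GPU lock is held for the duration of each operation.
        async with self.gpu_lock:
            return await self.worker(operation)


async def demo() -> tuple[list[Any], list[Any]]:
    async def fake_worker(operation: str) -> str:
        return f"ran {operation}"

    lora_queue: asyncio.Queue = asyncio.Queue()
    fft_queue: asyncio.Queue = asyncio.Queue()
    for operation in ("forward_backward", "optim_step"):
        lora_queue.put_nowait(operation)
        fft_queue.put_nowait(operation)

    lora = LoraTrainingRequestsProcessor(lora_queue, fake_worker)
    fft = FFTTrainingRequestsProcessor(fft_queue, fake_worker, asyncio.Lock())
    return await lora.drain(), await fft.drain()
```

Both processors accept the same operations off the queue; only `execute` differs, which is where the FFT variant takes the lock.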