# engine.py
"""Background engine for processing training requests."""
import argparse
import time
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from pathlib import Path
from typing import Callable
from cloudpathlib import AnyPath
from pydantic import BaseModel
from sqlmodel import Session, create_engine, func, select, update
from skyrl.backends.utils import log_timing
from skyrl.tinker import types
from skyrl.tinker.config import EngineConfig, add_model
from skyrl.tinker.db_models import (
CheckpointDB,
CheckpointStatus,
EngineStateDB,
FutureDB,
ModelDB,
RequestStatus,
SessionDB,
enable_sqlite_wal,
)
from skyrl.utils.log import logger


def _model_not_found_error(model_id: str) -> types.ErrorResponse:
    """Log and return an ErrorResponse for a request targeting a model that isn't loaded."""
    logger.info(
        f"Ignoring request for model '{model_id}' — model not loaded. "
        "This is most likely an outstanding request from a previous server."
    )
    return types.ErrorResponse(
        error=f"Model {model_id} not loaded (likely stale request from previous server)",
        status="failed",
    )


def prepare_sample_batch(
    requests: dict[str, tuple[str, types.SampleInput]],
    checkpoints_base: AnyPath | None = None,
) -> types.PreparedSampleBatch:
    """Prepare batch data for sample operations.

    Extracts prompts and sampling params from requests into lists
    that the backend will convert to arrays.

    Args:
        requests: Dict mapping request_id to (model_id, request_data) tuples (pre-validated)
        checkpoints_base: Base path for checkpoints (optional, needed for LoRA sampling)

    Returns:
        PreparedSampleBatch with all data extracted from requests
    """
    all_model_inputs = []
    all_sampling_params = []
    all_model_ids = []
    all_checkpoint_ids = []
    all_checkpoint_paths = []
    request_batch_slices = []
    needs_prompt_logprobs = any(request_data.prompt_logprobs for (_, request_data) in requests.values())
    for request_id, (model_id, request_data) in requests.items():
        request_start = len(all_model_inputs)
        # Expand requests for num_samples
        checkpoint_path = ""
        if model_id and request_data.checkpoint_id and checkpoints_base:
            checkpoint_path = str(
                checkpoints_base / model_id / "sampler_weights" / f"{request_data.checkpoint_id}.tar.gz"
            )
        for sample_idx in range(request_data.num_samples):
            all_model_inputs.append(request_data.prompt)
            # Derive a unique seed per sample so that num_samples > 1 produces
            # diverse sequences, matching vLLM's behavior (seed + index).
            sample_params = request_data.sampling_params.model_copy(
                update={"seed": request_data.sampling_params.seed + sample_idx}
            )
            all_sampling_params.append(sample_params)
            all_model_ids.append(model_id)
            all_checkpoint_ids.append(request_data.checkpoint_id)
            all_checkpoint_paths.append(checkpoint_path)
        request_batch_slices.append(
            (request_id, model_id, request_start, len(all_model_inputs), request_data.prompt_logprobs)
        )
    return types.PreparedSampleBatch(
        all_model_inputs=all_model_inputs,
        all_sampling_params=all_sampling_params,
        all_model_ids=all_model_ids,
        all_checkpoint_ids=all_checkpoint_ids,
        all_checkpoint_paths=all_checkpoint_paths,
        needs_prompt_logprobs=needs_prompt_logprobs,
        request_batch_slices=request_batch_slices,
    )
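

# Illustrative sketch (not part of the original module): how num_samples
# expansion interacts with request_batch_slices. Assuming a single pre-validated
# request asking for 3 samples of one prompt, the prepared batch holds three
# parallel entries whose seeds are seed, seed + 1, seed + 2, and the recorded
# slice (start=0, end=3) lets the backend map its flat outputs back to request "42":
#
#   batch = prepare_sample_batch({"42": ("model-a", sample_input)})
#   assert len(batch.all_model_inputs) == 3
#   assert [p.seed for p in batch.all_sampling_params] == [seed, seed + 1, seed + 2]
#   request_id, model_id, start, end, _ = batch.request_batch_slices[0]
#   # ("42", "model-a", 0, 3)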


def prepare_model_pass_batch(
    requests: dict[str, tuple[str, types.ForwardBackwardInput]],
) -> types.PreparedModelPassBatch:
    """Prepare batch data for forward/forward_backward operations.

    Extracts tokens, targets, and metadata from requests into lists
    that the backend will convert to arrays.

    Args:
        requests: Dict mapping request_id to (model_id, request_data) tuples (pre-validated)

    Returns:
        PreparedModelPassBatch with all data extracted from requests
    """
    all_model_inputs = []
    all_targets = []
    all_token_weights = []
    all_model_ids = []
    all_sampling_logprobs = []
    all_advantages = []
    all_values = []
    all_returns = []
    all_loss_fns = []
    all_loss_fn_configs = []
    request_batch_slices = []
    for request_id, (model_id, request_data) in requests.items():
        if request_data.loss_fn not in types.SUPPORTED_LOSS_FNS:
            raise ValueError(
                f"Unknown loss function '{request_data.loss_fn}'. Must be one of: {sorted(types.SUPPORTED_LOSS_FNS)}"
            )
        request_start = len(all_model_inputs)
        for item in request_data.data:
            all_model_inputs.append(item.model_input)
            loss_fn_inputs = item.loss_fn_inputs
            all_targets.append(loss_fn_inputs.target_tokens.data)
            all_token_weights.append(loss_fn_inputs.weights.data)
            all_sampling_logprobs.append(loss_fn_inputs.logprobs.data)
            all_advantages.append(loss_fn_inputs.advantages.data)
            all_values.append(loss_fn_inputs.values.data)
            all_returns.append(loss_fn_inputs.returns.data)
            all_model_ids.append(model_id)
            all_loss_fns.append(request_data.loss_fn)
            all_loss_fn_configs.append(request_data.loss_fn_config)
        request_batch_slices.append((request_id, model_id, request_start, len(all_model_inputs)))
    return types.PreparedModelPassBatch(
        all_model_inputs=all_model_inputs,
        all_targets=all_targets,
        all_token_weights=all_token_weights,
        all_sampling_logprobs=all_sampling_logprobs,
        all_advantages=all_advantages,
        all_values=all_values,
        all_returns=all_returns,
        all_model_ids=all_model_ids,
        all_loss_fns=all_loss_fns,
        all_loss_fn_configs=all_loss_fn_configs,
        request_batch_slices=request_batch_slices,
    )
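

# Illustrative note (not part of the original module): request_batch_slices
# stores half-open [start, end) ranges into the flat per-datum lists above.
# E.g., two requests with 2 and 3 datums respectively produce
#   [("7", "model-a", 0, 2), ("8", "model-b", 2, 5)]
# so the backend can split its flat per-datum outputs back per request.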


def get_backend_classes(backend_name: str):
    """Lazy import backends to avoid importing unused backend dependencies (e.g., JAX, Ray)."""
    if backend_name == "jax":
        from skyrl.backends.jax import JaxBackend, JaxBackendConfig

        return JaxBackend, JaxBackendConfig
    elif backend_name == "fsdp":
        from skyrl.backends.skyrl_train_backend import (
            FSDPBackendOverrides,
            SkyRLTrainBackend,
        )

        return SkyRLTrainBackend, FSDPBackendOverrides
    elif backend_name == "megatron":
        from skyrl.backends.skyrl_train_backend import (
            MegatronBackendOverrides,
            SkyRLTrainBackend,
        )

        return SkyRLTrainBackend, MegatronBackendOverrides
    else:
        raise ValueError(
            f"Unknown backend: {backend_name}. Available backends: jax, fsdp, megatron. "
            f"Make sure the backend's dependencies are installed (e.g., pip install skyrl[jax])"
        )
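

# Illustrative usage of get_backend_classes (a sketch mirroring what
# TinkerEngine.__init__ does below; the base model name is just an example):
#
#   backend_cls, backend_cfg_cls = get_backend_classes("jax")
#   backend = backend_cls("Qwen/Qwen2.5-1.5B", backend_cfg_cls())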


class TinkerEngine:
    """Background engine for processing training requests.

    The engine handles:
    - Database operations (futures, checkpoints)
    - Request finding/scheduling
    - File I/O (download/upload checkpoints)
    - Validating requests against loaded models

    Computation and model management are delegated to the backend.
    """

    def _filter_valid_requests(
        self,
        requests: dict[str, tuple[str, BaseModel]],
    ) -> tuple[dict[str, types.ErrorResponse], dict[str, tuple[str, BaseModel]]]:
        """Filter out requests with invalid model_ids and return error results for them.

        Args:
            requests: Dict mapping request_id to (model_id, request_data) tuples

        Returns:
            Tuple of (error_results, valid_requests)
        """
        results = {}
        valid_requests = {}
        for request_id, (model_id, request_data) in requests.items():
            error = None
            if model_id and not self.backend.has_model(model_id):
                error = f"Model {model_id} not loaded"
            elif not model_id and isinstance(request_data, types.SampleInput):
                if request_data.base_model != self.config.base_model:
                    error = f"Engine is configured for '{self.config.base_model}' but request specified '{request_data.base_model}'"
                elif request_data.checkpoint_id:
                    error = "checkpoint_id must be empty for base model sampling"
            if error:
                results[request_id] = types.ErrorResponse(error=error, status="failed")
            else:
                valid_requests[request_id] = (model_id, request_data)
        return results, valid_requests

    def __init__(
        self,
        config: EngineConfig,
    ):
        """Initialize the engine with a database connection and base model."""
        self.config = config
        self.db_engine = create_engine(config.database_url, echo=False)
        enable_sqlite_wal(self.db_engine)
        # Initialize the backend (handles model state, computation, and adapter management)
        backend_class, backend_config_class = get_backend_classes(config.backend)
        backend_config = backend_config_class(**config.backend_config)
        self.backend = backend_class(config.base_model, backend_config)
        # Backends that support async sample routing notify us when their
        # inference endpoint changes; we persist it to EngineStateDB so the
        # API process can forward sample requests directly. Backends stay
        # DB-free; only the engine owns the connection.
        if hasattr(self.backend, "set_inference_state_publisher"):
            self.backend.set_inference_state_publisher(self._write_inference_state_to_db)
        # Track last cleanup time for periodic stale session cleanup
        self._last_cleanup_time: float = time.time()
        logger.info(f"Initialized TinkerEngine with backend={type(self.backend).__name__}")

    @property
    def metrics(self) -> types.EngineMetrics:
        """Pass-through to backend metrics for backwards compatibility."""
        return self.backend.metrics

    def _write_inference_state_to_db(self, proxy_url: str | None) -> None:
        """Upsert the singleton EngineStateDB row.

        Wired into the backend via set_inference_state_publisher so the API
        process can resolve the engine-managed vLLM URL on the async sample
        routing path. ``proxy_url=None`` clears the row (post-teardown).
        """
        with Session(self.db_engine) as session:
            row = session.get(EngineStateDB, 1) or EngineStateDB(singleton_id=1)
            row.inference_proxy_url = proxy_url
            row.updated_at = datetime.now(timezone.utc)
            session.add(row)
            session.commit()

    @contextmanager
    def _checkpoint_status_context(self, model_id: str, checkpoint_id: str, checkpoint_type: types.CheckpointType):
        """Context manager to handle checkpoint DB status updates.

        Verifies that the checkpoint entry exists, runs the wrapped checkpoint
        operation, and updates the entry's status to COMPLETED or FAILED based
        on whether an exception occurred.
        """
        with Session(self.db_engine) as session:
            # Fail fast if API didn't create the checkpoint row first.
            if session.get(CheckpointDB, (model_id, checkpoint_id, checkpoint_type)) is None:
                raise ValueError(
                    f"Checkpoint entry not found for model '{model_id}', checkpoint '{checkpoint_id}', type '{checkpoint_type}'"
                )
        status = CheckpointStatus.FAILED
        error_message = "checkpoint operation interrupted"
        try:
            # Run potentially slow checkpoint I/O without an open DB transaction.
            yield
            status = CheckpointStatus.COMPLETED
            error_message = None
        except Exception as e:
            logger.exception(f"Error saving checkpoint for model {model_id}, checkpoint {checkpoint_id}: {e}")
            error_message = str(e)
            raise
        finally:
            # Persist final status in a short write transaction.
            with Session(self.db_engine) as session:
                result = session.exec(
                    update(CheckpointDB)
                    .where(CheckpointDB.model_id == model_id)
                    .where(CheckpointDB.checkpoint_id == checkpoint_id)
                    .where(CheckpointDB.checkpoint_type == checkpoint_type)
                    .values(
                        status=status,
                        error_message=error_message,
                        completed_at=datetime.now(timezone.utc),
                    )
                )
                if not result.rowcount:
                    logger.warning(
                        f"Checkpoint row disappeared before status update: "
                        f"model_id={model_id}, checkpoint_id={checkpoint_id}, checkpoint_type={checkpoint_type}"
                    )
                session.commit()

    def _find_destructive_barriers(self, session: Session) -> dict[str, int]:
        """Find the earliest pending destructive operation (optim_step/load_weights) per model.

        These act as scheduling barriers: model passes before them can be batched
        safely, and single requests after a blocked pass must wait.
        """
        query = (
            select(FutureDB.model_id, func.min(FutureDB.request_id).label("barrier_id"))
            .where(
                (FutureDB.request_type == types.RequestType.OPTIM_STEP)
                | (FutureDB.request_type == types.RequestType.LOAD_WEIGHTS)
            )
            .where(FutureDB.status == RequestStatus.PENDING)
            .group_by(FutureDB.model_id)
        )
        return dict(session.exec(query).all())
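
    # Illustrative sketch (not part of the original module): given a pending
    # queue for one model such as
    #
    #   request_id: 10 FORWARD_BACKWARD | 11 FORWARD_BACKWARD | 12 OPTIM_STEP | 13 FORWARD_BACKWARD
    #
    # _find_destructive_barriers returns {model_id: 12}, so requests 10 and 11
    # are batchable this iteration while 13 must wait behind the barrier.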

    def find_batchable_model_passes(
        self, session: Session, request_type: types.RequestType
    ) -> dict[str, tuple[str, types.ForwardBackwardInput]]:
        """Find all requests of the given type that come before any destructive update for their model.

        Uses look-ahead scheduling: for each model, only returns operations
        that have no optim_step or load_weights blocking them in the queue.

        Args:
            session: Database session
            request_type: The type of request to find (e.g., FORWARD or FORWARD_BACKWARD)

        Returns:
            Dict mapping request_id to (model_id, request_data) tuples
        """
        barriers = self._find_destructive_barriers(session)
        # Get all pending operations of the requested type ordered by request_id
        query = (
            select(FutureDB)
            .where(FutureDB.request_type == request_type)
            .where(FutureDB.status == RequestStatus.PENDING)
            .order_by(FutureDB.request_id)
        )
        ops = session.exec(query).all()
        # Filter: only include ops that come before their model's barrier
        batchable = [op for op in ops if op.model_id not in barriers or op.request_id < barriers[op.model_id]]
        return {
            str(f.request_id): (f.model_id, types.ForwardBackwardInput.model_validate(f.request_data))
            for f in batchable
        }

    def find_batchable_sample(self, session: Session) -> dict[str, tuple[str, types.SampleInput]]:
        """Find all sample ops that can be safely batched together.

        Returns sample operations ensuring that each model_id has only one checkpoint_id
        to avoid loading different checkpoints for the same model in a single batch.

        If sample_max_num_sequences is configured, limits to that many requests so we don't
        produce partial batches in process_sample_batch. If num_samples > 1 for some requests,
        this may not be perfect, but it's good enough until we implement continuous batching.

        Args:
            session: Database session

        Returns:
            Dict mapping request_id to (model_id, request_data) tuples
        """
        sample_query = (
            select(FutureDB)
            .where(FutureDB.request_type == types.RequestType.SAMPLE)
            .where(FutureDB.status == RequestStatus.PENDING)
            .order_by(FutureDB.request_id)
        )
        sample_ops = session.exec(sample_query).all()
        batchable = []
        model_checkpoints = {}  # Map from model_id to checkpoint_id of first request to that model
        for op in sample_ops:
            checkpoint_id = op.request_data["checkpoint_id"]
            # Base model requests (empty checkpoint_id) are always compatible; otherwise
            # take only requests matching the first checkpoint_id seen for a given model_id
            if checkpoint_id == "" or model_checkpoints.setdefault(op.model_id, checkpoint_id) == checkpoint_id:
                batchable.append(op)
        # TODO: This leaks the abstraction by accessing backend-specific config.
        # We should find a better way to handle this going forward.
        if self.config.backend == "jax" and self.backend.config.sample_max_num_sequences > 0:
            batchable = batchable[: self.backend.config.sample_max_num_sequences]
        return {str(f.request_id): (f.model_id, types.SampleInput.model_validate(f.request_data)) for f in batchable}
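
    # Illustrative sketch (not part of the original module): given pending sample
    # ops (model-a, ckpt-1), (model-a, ckpt-2), (model-a, ""), (model-b, ckpt-9),
    # the batch keeps (model-a, ckpt-1), (model-a, "") and (model-b, ckpt-9);
    # (model-a, ckpt-2) stays queued because model-a is already pinned to ckpt-1
    # for this batch.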

    def find_single_requests(self, session: Session) -> dict[str, tuple[str, types.RequestType, dict]]:
        """Find all requests that need to be processed individually (not batchable).

        Args:
            session: Database session

        Returns:
            Dict mapping request_id to (model_id, request_type, request_data) tuples
        """
        # Find the first blocked forward pass per model: a pending FORWARD/FORWARD_BACKWARD
        # that sits behind a destructive barrier and won't be batched this iteration.
        # Single requests must not jump ahead of these.
        destructive_barriers = self._find_destructive_barriers(session)
        blocked_pass_barriers: dict[str, int] = {}
        if destructive_barriers:
            pending_passes = session.exec(
                select(FutureDB.model_id, FutureDB.request_id)
                .where(
                    (FutureDB.request_type == types.RequestType.FORWARD_BACKWARD)
                    | (FutureDB.request_type == types.RequestType.FORWARD)
                )
                .where(FutureDB.status == RequestStatus.PENDING)
                .where(FutureDB.model_id.in_(destructive_barriers.keys()))
                .order_by(FutureDB.request_id)
            ).all()
            for model_id, req_id in pending_passes:
                if req_id >= destructive_barriers[model_id]:
                    blocked_pass_barriers.setdefault(model_id, req_id)
        statement = (
            select(FutureDB)
            .where(FutureDB.status == RequestStatus.PENDING)
            .where(FutureDB.request_type != types.RequestType.FORWARD_BACKWARD)
            .where(FutureDB.request_type != types.RequestType.FORWARD)
            .where(FutureDB.request_type != types.RequestType.SAMPLE)
            .where(FutureDB.request_type != types.RequestType.EXTERNAL)
            .order_by(FutureDB.request_id)
        )
        other_futures = session.exec(statement).all()
        # Filter: only include ops that come before the first blocked pass for their model
        other_futures = [
            op
            for op in other_futures
            if op.model_id not in blocked_pass_barriers or op.request_id < blocked_pass_barriers[op.model_id]
        ]
        return {str(f.request_id): (f.model_id, f.request_type, f.request_data) for f in other_futures}

    def process_create_model(self, model_id: str, request_data: types.CreateModelInput) -> types.CreateModelOutput:
        """Create and initialize a model."""
        # Create model in backend (allocates adapter_index, creates optimizer, and configures adapter)
        self.backend.create_model(model_id, request_data.lora_config, model_role=request_data.model_role)
        logger.info(f"Created LoRA model {model_id}")
        return types.CreateModelOutput(
            model_id=model_id,
            base_model=self.config.base_model,
            lora_config=request_data.lora_config,
        )

    def process_unload_model(self, model_id: str, request_data: types.UnloadModelInput) -> types.UnloadModelOutput:
        """Unload a model and free all resources."""
        if not self.backend.has_model(model_id):
            logger.warning(f"Ignoring unload request for model {model_id} that is not loaded.")
        else:
            self.backend.delete_model(model_id)
        # Update model status in DB
        with Session(self.db_engine) as session:
            _ = session.exec(update(ModelDB).where(ModelDB.model_id == model_id).values(status="unloaded"))
            session.commit()
        logger.info(f"Unloaded model {model_id}")
        return types.UnloadModelOutput(model_id=model_id, status="unloaded")

    def cleanup_stale_sessions(self) -> int:
        """Clean up sessions with no recent heartbeat and unload their models.

        Returns:
            Number of models unloaded
        """
        cutoff = datetime.now(timezone.utc) - timedelta(seconds=self.config.session_timeout_sec)
        unloaded_count = 0
        with Session(self.db_engine) as session:
            # Find stale sessions (active sessions with heartbeat older than cutoff)
            stale_sessions = session.exec(
                select(SessionDB).where(
                    SessionDB.status == "active",
                    SessionDB.last_heartbeat_at < cutoff,
                )
            ).all()
            if not stale_sessions:
                return 0
            stale_session_ids = {s.session_id for s in stale_sessions}
            # Find all models for all stale sessions in one query
            models_to_process = session.exec(
                select(ModelDB).where(
                    ModelDB.session_id.in_(stale_session_ids),
                    ModelDB.status != "unloaded",
                )
            ).all()
        # Unload models outside DB transactions to minimize lock time.
        sessions_with_failed_unloads: set[str] = set()
        unloaded_model_ids: set[str] = set()
        for model in models_to_process:
            if self.backend.has_model(model.model_id):
                try:
                    self.backend.delete_model(model.model_id)
                    unloaded_model_ids.add(model.model_id)
                    unloaded_count += 1
                    logger.info(f"Auto-unloaded stale model {model.model_id} from session {model.session_id}")
                except Exception as e:
                    logger.error(f"Failed to auto-unload model {model.model_id}: {e}")
                    sessions_with_failed_unloads.add(model.session_id)
            else:
                # Model already missing in backend; only DB state needs cleanup.
                unloaded_model_ids.add(model.model_id)
        sessions_to_expire = [s.session_id for s in stale_sessions if s.session_id not in sessions_with_failed_unloads]
        # Apply DB status updates in one short write transaction.
        with Session(self.db_engine) as session:
            if unloaded_model_ids:
                _ = session.exec(
                    update(ModelDB).where(ModelDB.model_id.in_(unloaded_model_ids)).values(status="unloaded")
                )
            if sessions_to_expire:
                _ = session.exec(
                    update(SessionDB).where(SessionDB.session_id.in_(sessions_to_expire)).values(status="expired")
                )
            session.commit()
        for session_id in sessions_to_expire:
            logger.info(f"Expired stale session {session_id}")
        return unloaded_count

    def process_optim_step(
        self, model_id: str, request_data: types.OptimStepInput
    ) -> types.OptimStepOutput | types.ErrorResponse:
        """Process an optim_step request and apply accumulated gradients."""
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)
        return self.backend.optim_step(model_id, request_data)

    def process_forward_backward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict:
        """Run forward and backward pass on a batch of requests."""
        prepared = prepare_model_pass_batch(requests)
        return self.backend.forward_backward(prepared)

    def process_forward(self, requests: dict[str, tuple[str, types.ForwardBackwardInput]]) -> dict:
        """Run forward-only pass on a batch of requests."""
        prepared = prepare_model_pass_batch(requests)
        return self.backend.forward(prepared)

    def process_sample(self, requests: dict[str, tuple[str, types.SampleInput]]) -> dict:
        """Generate samples for a batch of requests."""
        prepared = prepare_sample_batch(requests, self.config.checkpoints_base)
        return self.backend.sample(prepared)

    def process_load_weights(
        self, model_id: str, request_data: types.LoadWeightsInput
    ) -> types.LoadWeightsOutput | types.ErrorResponse:
        """Loads a clean, trimmed training checkpoint."""
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)
        checkpoint_path = (
            self.config.checkpoints_base / request_data.source_model_id / f"{request_data.checkpoint_id}.tar.gz"
        )
        self.backend.load_checkpoint(checkpoint_path, model_id)
        return types.LoadWeightsOutput(type="load_weights")

    def process_save_weights(
        self, model_id: str, request_data: types.SaveWeightsInput
    ) -> types.SaveWeightsOutput | types.ErrorResponse:
        """Saves a clean training checkpoint by converting the trimmed NNX graph
        to a pure dictionary before serialization, following the official Flax docs.
        """
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)
        checkpoint_id = request_data.path
        output_path = self.config.checkpoints_base / model_id / f"{checkpoint_id}.tar.gz"
        with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.TRAINING):
            self.backend.save_checkpoint(output_path, model_id)
        logger.info(f"Saved trimmed training checkpoint for model {model_id} to {output_path}")
        return types.SaveWeightsOutput(
            path=f"tinker://{model_id}/weights/{checkpoint_id}",
            type="save_weights",
        )

    def process_save_weights_for_sampler(
        self, model_id: str, request_data: types.SaveWeightsForSamplerInput
    ) -> types.SaveWeightsForSamplerOutput | types.ErrorResponse:
        """Process a save_weights_for_sampler request and save model weights."""
        if not self.backend.has_model(model_id):
            return _model_not_found_error(model_id)
        # Make sure the user cannot store checkpoints in places like ../../<important file>
        checkpoint_id = Path(request_data.path).name
        output_path = self.config.checkpoints_base / model_id / "sampler_weights" / f"{checkpoint_id}.tar.gz"
        # When the caller provides a sampling_session_seq_id the save is
        # transient — weights only need to reach the inference engines, not
        # disk. Backends can skip the expensive write in that case.
        persist = request_data.sampling_session_seq_id is None
        with self._checkpoint_status_context(model_id, checkpoint_id, types.CheckpointType.SAMPLER):
            self.backend.save_sampler_checkpoint(output_path, model_id, persist=persist)
        logger.info(f"Saved sampler checkpoint for model {model_id} to {output_path}")
        # Return path=None when using sampling_session_seq_id and seq_id (SDK expects this)
        if request_data.sampling_session_seq_id is not None and request_data.seq_id is not None:
            output_path_str = None
        else:
            output_path_str = f"tinker://{model_id}/{checkpoint_id}"
        return types.SaveWeightsForSamplerOutput(
            path=output_path_str,
            type="save_weights_for_sampler",
            sampling_session_id=request_data.sampling_session_id,
        )

    def _complete_futures(self, results: dict[str, BaseModel]):
        """Helper method to complete multiple futures in the database.

        Args:
            results: Dict mapping request_id to result (Pydantic BaseModel)
        """
        completed_at = datetime.now(timezone.utc)
        params = [
            {
                "request_id": int(request_id),
                "result_data": result.model_dump(),
                "status": RequestStatus.FAILED if isinstance(result, types.ErrorResponse) else RequestStatus.COMPLETED,
                "completed_at": completed_at,
            }
            for request_id, result in results.items()
        ]
        with Session(self.db_engine) as session:
            session.execute(update(FutureDB), params)
            session.commit()
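
    # Note (added, not in the original): passing a list of parameter dicts to
    # session.execute(update(FutureDB), params) uses SQLAlchemy's ORM
    # "bulk UPDATE by primary key" path, assuming request_id is FutureDB's
    # primary key: each dict carries the key plus the columns to set, and the
    # whole batch is emitted as a single executemany.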

    def process_single_request(self, request_type: types.RequestType, model_id: str, request_data: dict) -> BaseModel:
        match request_type:
            case types.RequestType.CREATE_MODEL:
                return self.process_create_model(model_id, types.CreateModelInput.model_validate(request_data))
            case types.RequestType.OPTIM_STEP:
                return self.process_optim_step(model_id, types.OptimStepInput.model_validate(request_data))
            case types.RequestType.SAVE_WEIGHTS_FOR_SAMPLER:
                return self.process_save_weights_for_sampler(
                    model_id, types.SaveWeightsForSamplerInput.model_validate(request_data)
                )
            case types.RequestType.SAVE_WEIGHTS:
                return self.process_save_weights(model_id, types.SaveWeightsInput.model_validate(request_data))
            case types.RequestType.LOAD_WEIGHTS:
                return self.process_load_weights(model_id, types.LoadWeightsInput.model_validate(request_data))
            case types.RequestType.UNLOAD_MODEL:
                return self.process_unload_model(model_id, types.UnloadModelInput.model_validate(request_data))
            case _:
                raise ValueError(f"Unknown request type: {request_type}")

    def process_single_requests(self, requests: dict[str, tuple[str, types.RequestType, dict]]):
        """Process a collection of single (non-batchable) requests.

        Args:
            requests: Dict mapping request_id to (model_id, request_type, request_data) tuples
        """
        if not requests:
            return
        results = {}
        for request_id, (model_id, request_type, request_data) in requests.items():
            with log_timing(f"process_single_request({request_type.value})"):
                try:
                    result = self.process_single_request(request_type, model_id, request_data)
                except Exception as e:
                    logger.exception(f"Error processing request {request_id}: {e}")
                    result = types.ErrorResponse(error=str(e), status="failed")
                results[request_id] = result
        self._complete_futures(results)

    def process_batch_requests(
        self,
        requests: dict[str, tuple[str, BaseModel]],
        processor: Callable[[dict[str, tuple[str, BaseModel]]], dict[str, BaseModel]],
        name: str,
    ):
        """Process a batch of requests with error handling and future completion.

        Args:
            requests: Dict mapping request_id to (model_id, request_data) tuples
            processor: Function that processes requests and returns results dict
            name: Name for logging
        """
        if not requests:
            return
        with log_timing(f"process_batch_requests({name}, n={len(requests)})"):
            try:
                error_results, valid_requests = self._filter_valid_requests(requests)
                if valid_requests:
                    results = processor(valid_requests)
                    results.update(error_results)
                else:
                    results = error_results
            except Exception as e:
                logger.exception(f"Error processing batch: {e}")
                results = {request_id: types.ErrorResponse(error=str(e), status="failed") for request_id in requests}
        self._complete_futures(results)

    def process_pending_requests(self):
        """Main loop to process pending requests."""
        while True:
            # Query for pending requests and extract data within session context
            with Session(self.db_engine) as session:
                # Use look-ahead scheduling to find batchable forward_backward and forward model passes
                forward_backward_requests = self.find_batchable_model_passes(
                    session, types.RequestType.FORWARD_BACKWARD
                )
                forward_requests = self.find_batchable_model_passes(session, types.RequestType.FORWARD)
                # Find pending sample requests that can be batched
                sample_requests = self.find_batchable_sample(session)
                # Get other pending requests (non forward_backward and non sampling)
                other_requests = self.find_single_requests(session)
            # Process batches outside of session context
            self.process_batch_requests(forward_backward_requests, self.process_forward_backward, "forward_backward")
            self.process_batch_requests(forward_requests, self.process_forward, "forward")
            self.process_batch_requests(sample_requests, self.process_sample, "sample")
            # Process other request types individually (in the future we can also batch independent optim_steps)
            self.process_single_requests(other_requests)
            # Periodically clean up stale sessions (disabled if either config value is negative)
            cleanup_enabled = self.config.session_cleanup_interval_sec >= 0 and self.config.session_timeout_sec >= 0
            if cleanup_enabled and time.time() - self._last_cleanup_time > self.config.session_cleanup_interval_sec:
                _ = self.cleanup_stale_sessions()
                self._last_cleanup_time = time.time()
            # Poll every 100ms
            time.sleep(0.1)

    def run(self):
        """Entry point to start the engine."""
        logger.info("Starting background engine...")
        self.process_pending_requests()


def main():
    """Entry point for the background engine."""
    # Create argument parser and add Pydantic model fields
    parser = argparse.ArgumentParser(description="SkyRL tinker engine for processing requests")
    add_model(parser, EngineConfig)
    # Parse command-line arguments
    args = parser.parse_args()
    # Create EngineConfig from parsed arguments
    config = EngineConfig.model_validate(vars(args))
    # Initialize and run the engine
    TinkerEngine(config).run()


if __name__ == "__main__":
    main()
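

# Illustrative launch (a sketch; the module path and flag spelling depend on how
# add_model(parser, EngineConfig) renders EngineConfig's fields and may differ):
#
#   python -m skyrl.tinker.engine --base_model Qwen/Qwen2.5-1.5B --backend jax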