-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathbase_backend.py
More file actions
120 lines (93 loc) · 3.7 KB
/
base_backend.py
File metadata and controls
120 lines (93 loc) · 3.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
import os
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Optional
from tinker import types
from ..checkpoints import CheckpointRecord
from ..config import ModelConfig
class BaseBackend(ABC):
"""Base class for all backends."""
def __init__(self, config: ModelConfig) -> None:
self.base_model = config.model_name
self.config = config
@abstractmethod
async def async_init(self) -> None:
"""Asynchronous initialization if needed."""
class BaseSamplingBackend(BaseBackend):
"""Abstract sampling backend."""
@abstractmethod
async def sample(
self,
prompt: types.ModelInput,
num_samples: int,
sampling_params: types.SamplingParams,
include_prompt_logprobs: bool = False,
topk_prompt_logprobs: int = 0,
lora_id: Optional[str] = None,
) -> types.SampleResponse:
"""Abstract method for sampling."""
@abstractmethod
async def add_adapter(self, lora_id: str, adapter_path: Path) -> None:
"""Add LoRA adapter to the backend."""
@abstractmethod
async def remove_adapter(self, lora_id: str) -> None:
"""Remove LoRA adapter from the backend."""
def get_openai_api_url(self) -> Optional[str]:
"""Return the vLLM OpenAI API base URL, or None if not available."""
return None
@classmethod
def create_backend(cls, config: ModelConfig) -> "BaseSamplingBackend":
"""Factory method to create a sampling backend instance."""
if os.getenv("TUFT_CPU_TEST", "0") == "1":
from ..backends.sampling_backend import DummySamplingBackend
return DummySamplingBackend(config)
else:
from ..backends.sampling_backend import VLLMSamplingBackend
return VLLMSamplingBackend(config)
class BaseTrainingBackend(BaseBackend):
"""Abstract training backend."""
@abstractmethod
async def forward(
self,
data: list[types.Datum],
lora_id: str,
loss_fn: types.LossFnType,
loss_fn_config: dict[str, float] | None,
backward: bool = False,
) -> types.ForwardBackwardOutput:
"""Abstract method for forward pass."""
@abstractmethod
async def create_adapter(self, lora_id: str, lora_config: types.LoraConfig) -> None:
"""Abstract method for creating LoRA adapter."""
@abstractmethod
async def remove_adapter(self, lora_id: str) -> None:
"""Abstract method for removing LoRA adapter."""
@abstractmethod
async def optim_step(
self,
adam_params: types.AdamParams,
lora_id: str,
) -> types.OptimStepResponse:
"""Abstract method for optimization step."""
@abstractmethod
async def save_state(
self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
) -> None:
"""Abstract method for saving model state."""
@abstractmethod
async def load_state(
self, lora_id: str, checkpoint_record: "CheckpointRecord", optimizer: bool
) -> None:
"""Abstract method for loading model state."""
@classmethod
def create_backend(cls, config: ModelConfig) -> "BaseTrainingBackend":
"""Factory method to create a training backend instance."""
if os.getenv("TUFT_CPU_TEST", "0") == "1":
from ..backends.training_backend import DummyTrainingBackend
return DummyTrainingBackend(config)
elif config.training_backend == "fsdp":
from ..backends.fsdp_training_backend import FSDPTrainingBackend
return FSDPTrainingBackend(config)
else:
from ..backends.training_backend import HFTrainingBackend
return HFTrainingBackend(config)