-
-
Notifications
You must be signed in to change notification settings - Fork 144
Expand file tree
/
Copy pathmodel.py
More file actions
184 lines (131 loc) · 5.38 KB
/
model.py
File metadata and controls
184 lines (131 loc) · 5.38 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
"""
Manages the storage and utility of model containers.
Containers exist as a common interface for backends.
"""
import pathlib
from enum import Enum
from fastapi import HTTPException
from loguru import logger
from typing import Optional
import time
from common.logger import get_loading_progress_bar
from common.networking import handle_request_error
from common.tabby_config import config
from common.optional_dependencies import dependencies
# Backend classes are optional dependencies; import them only when available.
if dependencies.exllamav2:
    from backends.exllamav2.model import ExllamaV2Container

if dependencies.extras:
    from backends.infinity.model import InfinityContainer

# Global model container.
# Bound unconditionally so the rest of this module can always test it against
# None, even when the exllamav2 backend is not installed. The annotation is
# quoted so the class name is not evaluated when its import was skipped.
container: Optional["ExllamaV2Container"] = None

# Global embeddings container (same reasoning as above for the quoting).
embeddings_container: Optional["InfinityContainer"] = None
class ModelType(Enum):
    """Kinds of model this module distinguishes, keyed by their string value."""

    # Primary model
    MODEL = "model"

    # Draft model
    DRAFT = "draft"

    # Embeddings model
    EMBEDDING = "embedding"
def load_progress(module, modules):
    """Progress callback: yield the received pair once, unchanged."""

    progress_pair = (module, modules)
    yield progress_pair
async def unload_model(skip_wait: bool = False, shutdown: bool = False):
    """Unload the global model container, if one exists.

    Args:
        skip_wait: Forwarded unchanged to ``container.unload``.
        shutdown: Forwarded unchanged to ``container.unload``.

    When no container is set, a warning is logged and nothing else happens.
    """

    global container

    # Guard against an unload request with nothing loaded; without this the
    # attribute access below raises AttributeError on None.
    if container is None:
        logger.warning("No model is loaded. Skipping unload request.")
        return

    await container.unload(skip_wait=skip_wait, shutdown=shutdown)

    # Drop the reference only after the backend finished unloading
    container = None
async def load_model_gen(model_path: pathlib.Path, **kwargs):
    """Async generator that loads a model and reports progress.

    Any previously loaded model is unloaded first. The keyword arguments are
    merged over ``config.model_defaults`` (explicit kwargs win) and passed to
    both container creation and ``container.load_gen``.

    Args:
        model_path: Directory of the model to load.
        **kwargs: Load options overriding the configured defaults.

    Yields:
        ``(module, modules, model_type)`` where the first two values come
        straight from ``container.load_gen`` and ``model_type`` is either
        ``"draft"`` or ``"model"``.

    Raises:
        ValueError: If a model with the same directory name is already loaded.
    """

    global container

    # Check if the model is already loaded
    if container and container.model:
        loaded_model_name = container.model_dir.name

        if loaded_model_name == model_path.name and container.model_loaded:
            raise ValueError(
                f'Model "{loaded_model_name}" is already loaded! Aborting.'
            )

        logger.info("Unloading existing model.")
        await unload_model()

    # Merge with config defaults
    kwargs = {**config.model_defaults, **kwargs}

    # Create a new container
    container = await ExllamaV2Container.create(model_path.resolve(), False, **kwargs)

    # Progress starts on the draft model when the container has a draft config
    model_type = "draft" if container.draft_config else "model"
    load_status = container.load_gen(load_progress, **kwargs)

    progress = get_loading_progress_bar()

    # Use a monotonic clock: wall-clock time may jump (NTP, manual changes)
    # and would then report a wrong or even negative duration.
    model_loading_started = time.perf_counter()
    progress.start()

    try:
        async for module, modules in load_status:
            # A zero counter marks the start of a new run of modules, so a
            # fresh progress task is created for it
            if module == 0:
                loading_task = progress.add_task(
                    f"[cyan]Loading {model_type} modules", total=modules
                )
            else:
                progress.advance(loading_task)

            yield module, modules, model_type

            if module == modules:
                # Switch to model progress if the draft model is loaded
                if model_type == "draft":
                    model_type = "model"
                else:
                    progress.stop()
    finally:
        # Always stop the progress bar, including on errors or early exit
        progress.stop()

    # Only reached when loading ran to completion
    model_loading_time = time.perf_counter() - model_loading_started
    logger.info(f"Model loading took {model_loading_time:.2f} seconds.")
async def load_model(model_path: pathlib.Path, **kwargs):
    """Load a model by running load_model_gen to completion.

    Every progress update produced by the generator is discarded.
    """

    progress_updates = load_model_gen(model_path, **kwargs)

    async for _update in progress_updates:
        continue
async def load_loras(lora_dir, **kwargs):
    """Load loras through the global container.

    If the container reports any loras, they are unloaded first. Returns
    whatever ``container.load_loras`` returns.
    """

    active_loras = container.get_loras()
    has_active_loras = len(active_loras) > 0

    if has_active_loras:
        await unload_loras()

    load_result = await container.load_loras(lora_dir, **kwargs)
    return load_result
async def unload_loras():
    """Ask the global container to unload with ``loras_only=True``."""

    unload_request = container.unload(loras_only=True)
    await unload_request
async def load_embedding_model(model_path: pathlib.Path, **kwargs):
    """Load an embeddings model into the global embeddings container.

    An existing embeddings model is unloaded first.

    Args:
        model_path: Directory of the embeddings model.
        **kwargs: Passed through to the container's ``load``.

    Raises:
        ImportError: If the extras dependencies are not installed.
        ValueError: If an embeddings model with the same directory name is
            already loaded.
    """

    global embeddings_container

    # Break out if infinity isn't installed
    extras_available = dependencies.extras
    if not extras_available:
        install_hint = (
            "Skipping embeddings because infinity-emb is not installed.\n"
            "Please run the following command in your environment "
            "to install extra packages:\n"
            "pip install -U .[extras]"
        )
        raise ImportError(install_hint)

    # Handle an embeddings model that is already present
    if embeddings_container and embeddings_container.engine:
        current_name = embeddings_container.model_dir.name
        same_model = (
            current_name == model_path.name and embeddings_container.model_loaded
        )

        if same_model:
            raise ValueError(
                f'Embeddings model "{current_name}" is already loaded! Aborting.'
            )

        logger.info("Unloading existing embeddings model.")
        await unload_embedding_model()

    # The global is assigned before loading starts
    embeddings_container = InfinityContainer(model_path)
    await embeddings_container.load(**kwargs)
async def unload_embedding_model():
    """Unload the global embeddings container, if one exists.

    When no container is set, a warning is logged and nothing else happens.
    """

    global embeddings_container

    # Guard against an unload request with nothing loaded; without this the
    # attribute access below raises AttributeError on None.
    if embeddings_container is None:
        logger.warning("No embedding model is loaded. Skipping unload request.")
        return

    await embeddings_container.unload()

    # Drop the reference only after the backend finished unloading
    embeddings_container = None
async def check_model_container():
    """FastAPI dependency that rejects requests when no model is usable.

    Raises:
        HTTPException: Status 400 when there is no container, or the container
            is neither loading nor loaded.
    """

    model_available = container is not None and (
        container.model_is_loading or container.model_loaded
    )
    if model_available:
        return

    error_response = handle_request_error(
        "No models are currently loaded.",
        exc_info=False,
    )
    raise HTTPException(400, error_response.error.message)
async def check_embeddings_container():
    """FastAPI dependency that rejects requests when no embeddings model is usable.

    Mirrors the model container check, but inspects the embeddings container.

    Raises:
        HTTPException: Status 400 when there is no embeddings container, or it
            is neither loading nor loaded.
    """

    embeddings_available = embeddings_container is not None and (
        embeddings_container.model_is_loading or embeddings_container.model_loaded
    )
    if embeddings_available:
        return

    error_response = handle_request_error(
        "No embedding models are currently loaded.",
        exc_info=False,
    )
    raise HTTPException(400, error_response.error.message)