import warnings
from dataclasses import dataclass
from functools import lru_cache
from typing import TYPE_CHECKING, Set, Tuple

from loguru import logger
from safetensors.torch import load_file
from transformers import AutoConfig, AutoTokenizer, PreTrainedTokenizer

from lorax_server.pb import generate_pb2
from lorax_server.utils.merges.strategies import merge_adapters
from lorax_server.utils.sources import get_config_path, get_model_source

if TYPE_CHECKING:
    from lorax_server.adapters.config import AdapterConfig, ModuleMap
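
# Sentinel adapter ID used to request the unmodified base model (no adapters applied).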
BASE_MODEL_ADAPTER_ID = "__base_model__"


@dataclass
class AdapterParametersContainer:
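    """Hashable wrapper around AdapterParameters.

    Hashing on adapter_index lets instances serve as lru_cache keys, since
    protobuf messages themselves are not hashable.
    """
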
    adapter_parameters: generate_pb2.AdapterParameters
    adapter_source: str
    adapter_index: int

    def __hash__(self) -> int:
        return self.adapter_index


def is_base_model(adapter_parameters: generate_pb2.AdapterParameters) -> bool:
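    """Return True if the request targets only the base model sentinel."""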
    if len(adapter_parameters.adapter_ids) != 1:
        return False
    return adapter_parameters.adapter_ids[0] == BASE_MODEL_ADAPTER_ID


def load_and_merge_adapters(
    model_id: str,
    adapter_parameters: generate_pb2.AdapterParameters,
    adapter_source: str,
    adapter_index: int,
    weight_names: Tuple[str, ...],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
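    """Load the requested adapter(s), merging them when more than one ID is given.

    A single adapter ID is delegated straight to load_module_map; multiple IDs
    are loaded individually and combined according to the merge strategy in
    adapter_parameters.
    """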
    if len(adapter_parameters.adapter_ids) == 1:
        return load_module_map(
            model_id, adapter_parameters.adapter_ids[0], adapter_source, weight_names, api_token, trust_remote_code
        )

    adapter_params = AdapterParametersContainer(adapter_parameters, adapter_source, adapter_index)
    return _load_and_merge(model_id, adapter_params, weight_names, api_token, trust_remote_code)


@lru_cache(maxsize=32)
def _load_and_merge(
    model_id: str,
    adapter_params: AdapterParametersContainer,
    weight_names: Tuple[str, ...],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
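    """Load every adapter in adapter_params and merge them into one module map.

    Results are cached via lru_cache, which is why the hashable
    AdapterParametersContainer is used here instead of the raw proto message.
    """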
    params = adapter_params.adapter_parameters

    adapters_to_merge = []
    merged_weight_names = set()
    tokenizer = None
    for adapter_id in params.adapter_ids:
        if adapter_id == BASE_MODEL_ADAPTER_ID:
            raise ValueError("Base model adapter cannot be merged.")

        module_map, adapter_config, adapter_weight_names, adapter_tokenizer = load_module_map(
            model_id,
            adapter_id,
            adapter_params.adapter_source,
            weight_names,
            api_token,
            trust_remote_code,
        )

        adapters_to_merge.append((module_map, adapter_config))
        merged_weight_names = merged_weight_names.union(adapter_weight_names)
        if tokenizer is None:
            tokenizer = adapter_tokenizer

    if len(adapters_to_merge) == 0:
        raise ValueError("No adapters to merge.")

    module_map, adapter_config = merge_adapters(adapters_to_merge, params)
    return module_map, adapter_config, merged_weight_names, tokenizer


def check_architectures(
    model_id: str,
    adapter_id: str,
    adapter_config: "AdapterConfig",
    trust_remote_code: bool = False,
):
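    """Check that the adapter was trained on a base model compatible with model_id.

    Matching architectures only produce a warning (the adapter may still work);
    differing architectures raise a ValueError.
    """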
    try:
        if not adapter_config.base_model_name_or_path:
            # Avoid the execution latency caused by network connection retries
            # for AutoConfig.from_pretrained(None)
            return

        expected_config = AutoConfig.from_pretrained(model_id, trust_remote_code=trust_remote_code)
        model_config = AutoConfig.from_pretrained(
            adapter_config.base_model_name_or_path, trust_remote_code=trust_remote_code
        )
    except Exception as e:
        warnings.warn(
            f"Unable to check architecture compatibility for adapter '{adapter_id}' "
            f"against model '{model_id}'. Assuming they are compatible. Error: {e}"
        )
        return

    if model_config.architectures == expected_config.architectures:
        warnings.warn(
            f"Adapter '{adapter_id}' was not trained on base model '{model_id}'. "
            f"If you encounter issues, use --model-id '{adapter_config.base_model_name_or_path}' instead."
        )
    else:
        # TODO(travis): revisit this when we support classification heads, which will not use CausalLM
        raise ValueError(
            f"Adapter '{adapter_id}' is not compatible with model '{model_id}'. "
            f"Architectures differ: {model_config.architectures} != {expected_config.architectures}. "
            f"Use --model-id '{adapter_config.base_model_name_or_path}' instead."
        )


@lru_cache(maxsize=128)
def load_module_map(
    model_id: str,
    adapter_id: str,
    adapter_source: str,
    weight_names: Tuple[str, ...],
    api_token: str,
    trust_remote_code: bool = False,
) -> Tuple["ModuleMap", "AdapterConfig", Set[str], PreTrainedTokenizer]:
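    """Load an adapter's config, weights, and optional tokenizer from adapter_source.

    Returns the module map tying base model weight names to LoRA weights, the
    adapter config, the set of weight names that were mapped, and the adapter's
    tokenizer (None when the adapter does not ship one).
    """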
    # TODO(geoffrey): refactor this and merge parts of this function with
    # lorax_server/utils/adapter.py::create_merged_weight_files
    source = get_model_source(adapter_source, adapter_id, extension=".safetensors", api_token=api_token)
    config_path = get_config_path(adapter_id, adapter_source)
    adapter_config = source.load_config()
    if adapter_config.base_model_name_or_path != model_id:
        check_architectures(model_id, adapter_id, adapter_config, trust_remote_code)

    try:
        adapter_tokenizer = AutoTokenizer.from_pretrained(
            config_path, token=api_token, trust_remote_code=trust_remote_code
        )
    except Exception:
        # The adapter does not ship its own tokenizer, so fall back to the base model's
        adapter_tokenizer = None

    # Load adapter weights from all shards (should have a relatively small memory footprint)
    adapter_filenames = source.weight_files()
    adapter_weights = {}
    for filename in adapter_filenames:
        adapter_weights.update(load_file(filename))

    # Map the model weights to the relevant adapter weights (LoRA A and B matrices)
    module_map, adapter_weight_names = adapter_config.map_weights_for_model(adapter_weights, weight_names)

    # note(ajinkya): the mapping above consumes the adapter weights it recognizes.
    # Any weights left over are unsupported; that should arguably be an error, but
    # for now we only log a warning.
    if len(adapter_weights) > 0:
        logger.warning(
            f"Adapter {adapter_id} for the model {model_id} "
            f"contains unsupported weights: {', '.join(adapter_weights.keys())}"
        )

    return module_map, adapter_config, adapter_weight_names, adapter_tokenizer
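

# Illustrative usage (a minimal sketch, not part of the server's public surface).
# The model ID, adapter ID, weight names, and token below are hypothetical
# placeholders; adapter_source and api_token must match whatever weight source
# and credentials your deployment actually uses.
#
#   module_map, adapter_config, mapped_names, tokenizer = load_module_map(
#       model_id="mistralai/Mistral-7B-v0.1",  # base model the server is running
#       adapter_id="acme/sql-lora",            # hypothetical adapter repo
#       adapter_source="hub",
#       weight_names=("q_proj", "v_proj"),     # layers targeted by the LoRA
#       api_token="",                          # or a real access token
#       trust_remote_code=False,
#   )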