dataclasses.py
import json
import os
import struct
from typing import Literal, Optional

import huggingface_hub
from huggingface_hub.constants import (
    SAFETENSORS_INDEX_FILE,
    SAFETENSORS_MAX_HEADER_LENGTH,
    SAFETENSORS_SINGLE_FILE,
)
from huggingface_hub.utils import (SafetensorsFileMetadata,
                                   SafetensorsRepoMetadata, TensorInfo)
from huggingface_hub.utils import tqdm as hf_tqdm
from pydantic import (AliasChoices, AliasPath, BaseModel, Field,
                      model_validator)
from tqdm.contrib.concurrent import thread_map
from transformers import AutoConfig


def parse_safetensors_file_metadata(model_path, filename):
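    """Parse the safetensors header of a single local file.

    A safetensors file begins with an 8-byte little-endian unsigned integer
    giving the length of a JSON header; the header maps tensor names to
    dtype/shape/data_offsets entries, plus an optional "__metadata__" block.
    """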
with open(os.path.join(model_path, filename), "rb") as f:
metadata_size = f.read(8)
metadata_size = struct.unpack("<Q", metadata_size)[0]
if metadata_size > SAFETENSORS_MAX_HEADER_LENGTH:
raise RuntimeError(
f"Failed to parse safetensors header for '{filename}' (model_path '{model_path}'): "
f"safetensors header is too big. Maximum supported size is "
f"{SAFETENSORS_MAX_HEADER_LENGTH} bytes (got {metadata_size}).")
metadata_as_bytes = f.read(metadata_size)
try:
metadata_as_dict = json.loads(metadata_as_bytes.decode(errors="ignore"))
except json.JSONDecodeError as e:
raise RuntimeError(
f"Failed to parse safetensors header for '{filename}' (model_path '{model_path}'): "
"header format not recognized. Please make sure this is a correctly formatted safetensors file."
) from e
try:
return SafetensorsFileMetadata(
metadata=metadata_as_dict.get("__metadata__", {}),
tensors={
key:
TensorInfo(
dtype=tensor["dtype"],
shape=tensor["shape"],
data_offsets=tuple(tensor["data_offsets"]), # type: ignore
)
for key, tensor in metadata_as_dict.items()
if key != "__metadata__"
},
)
except (KeyError, IndexError) as e:
raise RuntimeError(
f"Failed to parse safetensors header for '{filename}' (model_path '{model_path}'): "
"header format not recognized. Please make sure this is a correctly formatted safetensors file."
) from e


def get_safetensors_metadata(model_name_or_path):
""" Read the safetensors metadata from HF model. """
if os.path.isdir(model_name_or_path):
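        # A local checkpoint is either a single SAFETENSORS_SINGLE_FILE
        # ("model.safetensors") or a SAFETENSORS_INDEX_FILE index pointing at
        # sharded weight files.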
if os.path.exists(
os.path.join(model_name_or_path, SAFETENSORS_SINGLE_FILE)):
file_metadata = parse_safetensors_file_metadata(
model_path=model_name_or_path, filename=SAFETENSORS_SINGLE_FILE)
return SafetensorsRepoMetadata(
metadata=None,
sharded=False,
weight_map={
tensor_name: SAFETENSORS_SINGLE_FILE
for tensor_name in file_metadata.tensors.keys()
},
files_metadata={SAFETENSORS_SINGLE_FILE: file_metadata},
)
elif os.path.exists(
os.path.join(model_name_or_path, SAFETENSORS_INDEX_FILE)):
with open(os.path.join(model_name_or_path,
SAFETENSORS_INDEX_FILE)) as f:
index = json.load(f)
weight_map = index.get("weight_map", {})
# Fetch metadata per shard
files_metadata = {}

            def _parse(filename: str) -> None:
files_metadata[filename] = parse_safetensors_file_metadata(
model_path=model_name_or_path, filename=filename)
thread_map(
_parse,
set(weight_map.values()),
desc="Parse safetensors files",
tqdm_class=hf_tqdm,
)
return SafetensorsRepoMetadata(
metadata=index.get("metadata", None),
sharded=True,
weight_map=weight_map,
files_metadata=files_metadata,
)
else:
# Not a safetensors repo
raise RuntimeError(
f"'{model_name_or_path}' is not a safetensors repo. Couldn't find '{SAFETENSORS_INDEX_FILE}' or '{SAFETENSORS_SINGLE_FILE}' files."
)
else:
return huggingface_hub.get_safetensors_metadata(model_name_or_path)


class ModelConfig(BaseModel):
""" Model specific configurations. The parameters are needed in engine
setting calculation.
"""
name: str
model_type: str
param_count: int
num_hidden_layers: int = Field(validation_alias=AliasChoices(
"num_hidden_layers",
"n_layer",
AliasPath("text_config", "num_hidden_layers"),
AliasPath("language_config", "num_hidden_layers"),
))
num_attention_layers: Optional[int] = Field(default=None)
num_attention_heads: int = Field(validation_alias=AliasChoices(
"num_attention_heads",
"n_head",
AliasPath("text_config", "num_attention_heads"),
AliasPath("language_config", "num_attention_heads"),
))
num_key_value_heads: Optional[int] = Field(
default=None,
validation_alias=AliasChoices(
"num_key_value_heads",
"num_kv_heads",
AliasPath("text_config", "num_key_value_heads"),
AliasPath("language_config", "num_key_value_heads"),
),
)
hidden_size: int = Field(validation_alias=AliasChoices(
"hidden_size",
"n_embd",
AliasPath("text_config", "hidden_size"),
))
head_size: Optional[int] = Field(default=None,
validation_alias=AliasChoices(
"head_size",
"head_dim",
"attention_head_dim",
AliasPath("text_config", "head_dim"),
))
max_position_embeddings: Optional[int] = Field(
default=None,
validation_alias=AliasChoices(
"max_position_embeddings",
"n_positions",
AliasPath("text_config", "max_position_embeddings"),
))
dtype: Literal["float16", "bfloat16", "float32",
None] = Field(default="float16",
validation_alias=AliasChoices(
"dtype", "torch_dtype"))

    @model_validator(mode="after")
def set_values_if_none(self):
""" Set the values if cannot get values from HF config.json. """
if not self.dtype: # for GPT-J
self.dtype = "float16"
if self.num_key_value_heads is None:
self.num_key_value_heads = self.num_attention_heads
if self.head_size is None:
self.head_size = self.hidden_size // self.num_attention_heads
if self.num_attention_layers is None:
self.num_attention_layers = self.num_hidden_layers
return self

    @classmethod
def get_param_count(cls, model_hf_name, hf_model_path):
""" Read the parameter count from HF safetensor metadata. """
if model_hf_name == "EleutherAI/gpt-j-6b": # GPT-J repo doesn't use safetensor format.
param_count = 6053381344
else:
model_name_or_path = hf_model_path or model_hf_name
metadata = get_safetensors_metadata(model_name_or_path)
param_count = sum(metadata.parameter_count.values())
assert param_count, f"Can't get valid parameter count for model: {model_name_or_path}."
return param_count

    @classmethod
def from_hf(cls, model_hf_name, hf_model_path):
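        """Build a ModelConfig from the model's HF config plus its safetensors parameter count."""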
model_name_or_path = hf_model_path or model_hf_name
hf_config = AutoConfig.from_pretrained(
model_name_or_path, trust_remote_code=True).to_dict()
param_count = cls.get_param_count(model_hf_name, hf_model_path)
return cls(name=model_hf_name, param_count=param_count, **hf_config)

    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
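        """Per-sequence cache size in GiB beyond the attention KV cache; zero for attention-only models."""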
return 0

    def cache_memory_fraction(self, cache_memory_fraction):
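        """Adjust the requested cache memory fraction; attention-only models use it unchanged."""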
return cache_memory_fraction


class NemotronHybridConfig(ModelConfig):
hybrid_override_pattern: str
d_state: int = Field(validation_alias=AliasChoices(
"d_state",
"mamba_d_state",
"ssm_state_size",
))
d_conv: int = Field(validation_alias=AliasChoices(
"d_conv",
"mamba_d_conv",
"conv_kernel",
))
mamba_num_heads: int
n_groups: int
mamba_head_dim: int
d_inner: Optional[int] = Field(default=None)
num_mamba_layers: Optional[int] = Field(default=None)
mamba_ssm_cache_dtype: Optional[str] = Field(default="auto")

    @model_validator(mode="after")
def set_values_if_none(self):
""" Set the values if cannot get values from HF config.json. """
if not self.d_inner:
self.d_inner = self.mamba_num_heads * self.mamba_head_dim
if self.num_mamba_layers is None:
self.num_mamba_layers = self.hybrid_override_pattern.count("M")
if self.num_attention_layers is None:
self.num_attention_layers = self.hybrid_override_pattern.count("*")
super().set_values_if_none()
return self

    def extra_model_cache_in_gb(self, bytes_per_elem, target_seq_len=None):
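        # Per Mamba layer, a cache entry holds a conv state of
        # conv_dim x (d_conv - 1) elements and an SSM state of
        # mamba_num_heads x mamba_head_dim x d_state elements.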
conv_dim = self.d_inner + 2 * self.n_groups * self.d_state
conv_state_elems = conv_dim * (self.d_conv - 1)
ssm_state_elems = self.mamba_num_heads * self.mamba_head_dim * self.d_state
gb_per_mamba_cache = bytes_per_elem * self.num_mamba_layers * (
conv_state_elems + ssm_state_elems) / (1024**3)
return gb_per_mamba_cache

    def cache_memory_fraction(self, cache_memory_fraction):
        # Each Mamba cache entry is fairly large (~50 MB for an 8B model), so
        # be more conservative when estimating the max batch size.
return cache_memory_fraction**2

    def set_mamba_ssm_cache_dtype(self, mamba_ssm_cache_dtype: str):
self.mamba_ssm_cache_dtype = mamba_ssm_cache_dtype
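

# Minimal usage sketch (illustrative, not part of this module's API surface):
# "meta-llama/Llama-3.1-8B" below is an assumed example checkpoint; any HF
# model with safetensors weights and a compatible config.json should work.
#
#   config = ModelConfig.from_hf("meta-llama/Llama-3.1-8B", hf_model_path=None)
#   print(config.param_count, config.num_attention_heads, config.head_size)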