-
Notifications
You must be signed in to change notification settings - Fork 439
Expand file tree
/
Copy pathdev.py
More file actions
151 lines (123 loc) · 4.96 KB
/
dev.py
File metadata and controls
151 lines (123 loc) · 4.96 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import contextlib
import logging
import os
import tempfile
from functools import wraps
from typing import Type
import torch
from compressed_tensors.offload import dispatch_model
from compressed_tensors.utils import patch_attr
from huggingface_hub import snapshot_download
from loguru import logger
from safetensors.torch import save_file
from transformers import AutoModelForCausalLM, PreTrainedModel
try:
from transformers.modeling_utils import TORCH_INIT_FUNCTIONS
except ImportError: # transformers>=5 moved this
from transformers.initialization import TORCH_INIT_FUNCTIONS
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, WEIGHTS_INDEX_NAME
# Public API of this module.
# NOTE(review): `skip_weights_initialize` is defined below but deliberately not
# exported here — confirm that is intentional.
__all__ = [
    "skip_weights_download",
    "patch_transformers_logger_level",
    "get_main_device",
    "dispatch_for_generation",
]
@contextlib.contextmanager
def skip_weights_download(model_class: Type[PreTrainedModel] = AutoModelForCausalLM):
    """
    Context manager under which models are initialized without having to download
    the model weight files. This differs from `init_empty_weights` in that weights are
    allocated on to assigned devices with random values, as opposed to being on the meta
    device

    Implementation: `model_class.from_pretrained` is temporarily replaced with a
    patched version that downloads everything from the hub EXCEPT weight files,
    drops an empty `model.safetensors` placeholder into a temp dir, and loads the
    model from there (with weight initialization skipped and transformers
    missing-weight warnings silenced).

    :param model_class: class to patch, defaults to `AutoModelForCausalLM`
    """
    original_fn = model_class.from_pretrained
    # Glob patterns / index filenames for checkpoint weights that must NOT be downloaded
    weights_files = [
        "*.bin",
        "*.safetensors",
        "*.pth",
        SAFE_WEIGHTS_INDEX_NAME,
        WEIGHTS_INDEX_NAME,
        "*.msgpack",
        "*.pt",
    ]

    @classmethod
    def patched(cls, *args, **kwargs):
        # `tmp_dir` is bound by the `with` block below before this patch can run
        nonlocal tmp_dir
        # intercept model stub (positional or keyword form of from_pretrained)
        model_stub = args[0] if args else kwargs.pop("pretrained_model_name_or_path")
        # download files into tmp dir
        os.makedirs(tmp_dir, exist_ok=True)
        snapshot_download(
            repo_id=model_stub, local_dir=tmp_dir, ignore_patterns=weights_files
        )
        # make an empty weights file to avoid errors
        weights_file_path = os.path.join(tmp_dir, "model.safetensors")
        save_file({}, weights_file_path, metadata={"format": "pt"})
        # load from tmp dir
        model = original_fn(tmp_dir, **kwargs)
        # replace model_path so the returned model reports the original hub stub
        model.name_or_path = model_stub
        model.config._name_or_path = model_stub
        return model

    with (
        tempfile.TemporaryDirectory() as tmp_dir,
        patch_attr(model_class, "from_pretrained", patched),
        skip_weights_initialize(),
        patch_transformers_logger_level(),
    ):
        yield
@contextlib.contextmanager
def skip_weights_initialize(use_zeros: bool = False):
    """
    Very similar to `transformers.model_utils.no_init_weights`, except that torch.Tensor
    initialization functions are also patched to account for tensors which are
    initialized not on the meta device

    :param use_zeros: if True, patched initializers zero-fill the tensor in place;
        otherwise the tensor data is left untouched
    """

    def _noop_init(tensor: torch.Tensor, *args, **kwargs) -> torch.Tensor:
        # Stand-in for every init function: optionally zero, never randomize
        return tensor.fill_(0) if use_zeros else tensor

    with contextlib.ExitStack() as stack:
        # Patch both the functional API (torch.nn.init.*) and the same-named
        # in-place torch.Tensor methods
        for init_name in TORCH_INIT_FUNCTIONS:
            stack.enter_context(patch_attr(torch.nn.init, init_name, _noop_init))
            stack.enter_context(patch_attr(torch.Tensor, init_name, _noop_init))
        yield
@contextlib.contextmanager
def patch_transformers_logger_level(level: int = logging.ERROR):
    """
    Context under which the transformers logger's level is modified

    This can be used with `skip_weights_download` to squelch warnings related to
    missing parameters in the checkpoint

    :param level: new logging level for transformers logger. Logs whose level is below
        this level will not be logged
    """
    transformers_logger = logging.getLogger("transformers.modeling_utils")
    # Save the logger's own level rather than getEffectiveLevel(): a NOTSET logger
    # should be restored to NOTSET, not pinned to the level it inherited
    restore_log_level = transformers_logger.level
    transformers_logger.setLevel(level=level)
    try:
        yield
    finally:
        # Restore even if the body raises, so the patched level cannot leak
        transformers_logger.setLevel(level=restore_log_level)
def get_main_device() -> torch.device:
    """
    Select the primary device for compression work.

    Prefers the first CUDA device, then the first XPU device (when the torch
    build exposes one); falls back to CPU with a warning.

    :return: torch.device to place the model on
    """
    if torch.cuda.is_available():
        return torch.device("cuda:0")

    xpu_available = hasattr(torch, "xpu") and torch.xpu.is_available()
    if xpu_available:
        return torch.device("xpu:0")

    logger.warning("CUDA/XPU is not available! Compressing model on CPU instead")
    return torch.device("cpu")
@wraps(dispatch_model)
def dispatch_for_generation(*args, **kwargs) -> PreTrainedModel:
    """
    Dispatch a model for autoregressive generation. This means that modules are
    dispatched evenly across available devices and kept onloaded if possible.

    Thin pass-through wrapper around `compressed_tensors.offload.dispatch_model`;
    all arguments are forwarded unchanged.

    :param model: model to dispatch
    :param hint_batch_size: reserve memory for batch size of inputs
    :param hint_batch_seq_len: reserve memory for sequence of length of inputs
    :param hint_model_dtype: reserve memory for model's dtype.
        Will be inferred from model if none is provided
    :param hint_extra_memory: extra memory reserved for model serving
    :param no_split_modules: names of module classes which should not be split
        across multiple devices
    :return: dispatched model
    """
    return dispatch_model(*args, **kwargs)