-
Notifications
You must be signed in to change notification settings - Fork 15
Expand file tree
/
Copy pathauto_model.py
More file actions
92 lines (86 loc) · 4.41 KB
/
auto_model.py
File metadata and controls
92 lines (86 loc) · 4.41 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
from .llama import Llama, LlamaAwq, LlamaOffload, LlamaAwqOffload, LlamaCudagraph
from .qwen import Qwen, QwenOffload, QwenCudagraph
class AutoModelLM:
    """
    Factory that maps a Hugging Face model id to the wrapper class
    implementing it, and instantiates that class.

    Three registries exist, one per execution mode selected in
    ``from_pretrained``:

    * ``_MODEL_MAPPING``           -- default mode.
    * ``_OFFLOAD_MODEL_MAPPING``   -- used when ``offload=True``.
    * ``_CUDAGRAPH_MODEL_MAPPING`` -- used when ``cuda_graph=True``.

    A model id may appear in some registries and not others; support is
    checked per mode.
    """

    # Registry for offload mode. AWQ-quantized Llama checkpoints get the
    # AWQ-aware offload wrapper; full-precision ones get the plain one.
    _OFFLOAD_MODEL_MAPPING = {
        "ibnzterrell/Meta-Llama-3.3-70B-Instruct-AWQ-INT4": LlamaAwqOffload,
        "lambdalabs/Llama-3.3-70B-Instruct-AWQ-4bit": LlamaAwqOffload,
        "casperhansen/llama-3.3-70b-instruct-awq": LlamaAwqOffload,
        "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4": LlamaAwqOffload,
        "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": LlamaAwqOffload,
        "meta-llama/Llama-3.3-70B-Instruct": LlamaOffload,
        "meta-llama/Llama-3.1-70B-Instruct": LlamaOffload,
        "meta-llama/Llama-3.1-8B-Instruct": LlamaOffload,
        "meta-llama/Meta-Llama-3-70B-Instruct": LlamaOffload,
        "meta-llama/Meta-Llama-3-8B-Instruct": LlamaOffload,
        "Qwen/Qwen2.5-3B-Instruct": QwenOffload,
        "Qwen/Qwen2.5-0.5B-Instruct": QwenOffload,
    }

    # Registry for the default mode (neither offload nor CUDA graphs).
    _MODEL_MAPPING = {
        "ibnzterrell/Meta-Llama-3.3-70B-Instruct-AWQ-INT4": LlamaAwq,
        "lambdalabs/Llama-3.3-70B-Instruct-AWQ-4bit": LlamaAwq,
        "casperhansen/llama-3.3-70b-instruct-awq": LlamaAwq,
        "hugging-quants/Meta-Llama-3.1-70B-Instruct-AWQ-INT4": LlamaAwq,
        "hugging-quants/Meta-Llama-3.1-8B-Instruct-AWQ-INT4": LlamaAwq,
        "meta-llama/Llama-3.3-70B-Instruct": Llama,
        "meta-llama/Llama-3.1-70B-Instruct": Llama,
        "meta-llama/Llama-3.1-8B-Instruct": Llama,
        "meta-llama/Meta-Llama-3-70B-Instruct": Llama,
        "meta-llama/Meta-Llama-3-8B-Instruct": Llama,
        "meta-llama/Llama-3.2-1B-Instruct": Llama,
        "meta-llama/Llama-3.2-3B-Instruct": Llama,
        "Felladrin/Llama-68M-Chat-v1": Llama,
        "facebook/layerskip-llama3.2-1B": Llama,
        "Zhuominc/Llama-3-330M": Llama,
        "Zhuominc/Coder-670M": Llama,
        "Zhuominc/Coder-400M": Llama,
        "Zhuominc/Coder-400M-IT": Llama,
        "Zhuominc/FastCode-500M": Llama,
        "InfiniAILab/CodeDrafter-500M": Llama,
        "Qwen/Qwen2.5-3B-Instruct": Qwen,
        "Qwen/Qwen2.5-0.5B-Instruct": Qwen,
    }

    # Registry for CUDA-graph mode. Only a subset of the default registry
    # (no 70B or AWQ checkpoints) is registered here.
    _CUDAGRAPH_MODEL_MAPPING = {
        "meta-llama/Llama-3.1-8B-Instruct": LlamaCudagraph,
        "meta-llama/Meta-Llama-3-8B-Instruct": LlamaCudagraph,
        "meta-llama/Llama-3.2-1B-Instruct": LlamaCudagraph,
        "meta-llama/Llama-3.2-3B-Instruct": LlamaCudagraph,
        "Felladrin/Llama-68M-Chat-v1": LlamaCudagraph,
        "facebook/layerskip-llama3.2-1B": LlamaCudagraph,
        "Zhuominc/Llama-3-330M": LlamaCudagraph,
        "Zhuominc/Coder-670M": LlamaCudagraph,
        "Zhuominc/Coder-400M": LlamaCudagraph,
        "Zhuominc/Coder-400M-IT": LlamaCudagraph,
        "Zhuominc/FastCode-500M": LlamaCudagraph,
        "InfiniAILab/CodeDrafter-500M": LlamaCudagraph,
        "Qwen/Qwen2.5-3B-Instruct": QwenCudagraph,
        "Qwen/Qwen2.5-0.5B-Instruct": QwenCudagraph,
    }

    @staticmethod
    def _lookup(mapping, model_name, mode=None):
        """
        Return ``mapping[model_name]``, or raise a ValueError naming the mode.

        :param mapping: one of the class's model registries.
        :param model_name: Hugging Face model id to look up.
        :param mode: short label for the registry (e.g. ``"offload"``), shown
            in the error message; ``None`` for the default registry.
        :return: the wrapper class registered for ``model_name``.
        :raises ValueError: if ``model_name`` is not a key of ``mapping``.
        """
        try:
            return mapping[model_name]
        except KeyError:
            # Tag the message with the mode, so a model that is supported in
            # another mode is not reported as unsupported altogether.
            suffix = f" ({mode})" if mode else ""
            raise ValueError(
                f"Model type '{model_name}' is not supported{suffix}. "
                f"Supported{suffix} types: {list(mapping)}"
            ) from None

    @classmethod
    def from_pretrained(cls, model_name, offload=False, cuda_graph=False, **kwargs):
        """
        Instantiate the wrapper class registered for ``model_name``.

        :param model_name: Hugging Face model id, e.g.
            ``"meta-llama/Llama-3.1-8B-Instruct"``. It must be a key of the
            registry selected by ``offload`` / ``cuda_graph``.
        :param offload: if true, look the model up in the offload registry.
            Ignored when ``cuda_graph`` is true.
        :param cuda_graph: if true, look the model up in the CUDA-graph
            registry. Takes precedence over ``offload``.
        :param kwargs: forwarded unchanged to the model class constructor.
        :return: ``model_class(model_name=model_name, **kwargs)`` for the
            registered class.
        :raises ValueError: if ``model_name`` is not registered for the
            selected mode.
        """
        # Mode selection order: cuda_graph > offload > default.
        if cuda_graph:
            mapping, mode = cls._CUDAGRAPH_MODEL_MAPPING, "cuda graph"
        elif offload:
            mapping, mode = cls._OFFLOAD_MODEL_MAPPING, "offload"
        else:
            mapping, mode = cls._MODEL_MAPPING, None
        model_class = cls._lookup(mapping, model_name, mode)
        return model_class(model_name=model_name, **kwargs)