-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy path: hubconf.py
More file actions
164 lines (139 loc) · 5.53 KB
/
hubconf.py
File metadata and controls
164 lines (139 loc) · 5.53 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
# Torch Hub dependency declaration: packages that must be importable before
# the entry points in this hubconf can be loaded via torch.hub.load(...).
dependencies = [
    "torch", "torchaudio", "numpy", "librosa",
    "transformers", "snac", "msclap", "safetensors",
]
import logging
import os
import torch
# Local MARS6 imports
from mars6_turbo.ar_model import Mars6_Turbo, SNACTokenizerInfo
from mars6_turbo.minbpe.regex import RegexTokenizer
########################################################################
# Checkpoint & tokenizer URLs
########################################################################
# Default .pt model checkpoint from the v0.1 GitHub release.
MARS6_CKPT_PT_URL = "https://github.com/Camb-ai/mars6-turbo/releases/download/v0.1/model-2000100.pt"
# BPE text-tokenizer model file (filename suggests a 512-entry English vocab).
# NOTE(review): mars6_turbo() below also references MARS6_CKPT_SAFETENSORS_URL,
# which is not defined in this module — confirm whether that constant was lost.
MARS6_TOKENIZER_URL = "https://github.com/Camb-ai/mars6-turbo/releases/download/v0.1/eng-tok-512.model"
def mars6_turbo(
    pretrained: bool = True,
    progress: bool = True,
    device: str = None,
    dtype: torch.dtype = torch.half,
    ckpt_format: str = "pt",
    checkpoint_url: str = None,
    tokenizer_url: str = None,
):
    """
    Torch Hub entry point for MARS6.

    - pretrained: must be True if you want to load the pretrained model
    - progress: whether to show download progress
    - device: 'cuda' or 'cpu' (defaults to GPU if available)
    - dtype: torch.half or torch.float
    - ckpt_format: 'pt' or 'safetensors'
    - checkpoint_url: optional override if hosting your own checkpoint
    - tokenizer_url: optional override if hosting your own tokenizer

    Returns:
        (model, tokenizer) so you can run model.inference(...) or other code.

    Raises:
        ValueError: if pretrained is False, if ckpt_format is not
            'pt'/'safetensors', or if ckpt_format='safetensors' is requested
            without a checkpoint_url and no module-level default exists.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # Validate with an explicit raise: `assert` is stripped under `python -O`.
    if ckpt_format not in ("pt", "safetensors"):
        raise ValueError("ckpt_format must be 'pt' or 'safetensors'")
    if not pretrained:
        raise ValueError("Currently only pretrained MARS6 is supported.")

    # Decide which URLs to use (user-provided overrides win).
    if checkpoint_url is None:
        if ckpt_format == "safetensors":
            # BUG FIX: the original referenced MARS6_CKPT_SAFETENSORS_URL,
            # which is not defined anywhere in this module, so this branch
            # always died with NameError. Pick the constant up if someone
            # defines it at module level; otherwise fail with an actionable
            # message instead of an opaque NameError.
            checkpoint_url = globals().get("MARS6_CKPT_SAFETENSORS_URL")
            if checkpoint_url is None:
                raise ValueError(
                    "No default safetensors checkpoint URL is configured; "
                    "pass checkpoint_url= when using ckpt_format='safetensors'."
                )
        else:
            checkpoint_url = MARS6_CKPT_PT_URL
    if tokenizer_url is None:
        tokenizer_url = MARS6_TOKENIZER_URL

    logging.info(f"Using device: {device}")

    ############################################################################
    # 1) Load checkpoint
    ############################################################################
    if ckpt_format == "safetensors":
        ckpt = _load_safetensors_ckpt(checkpoint_url, progress)
    else:
        # standard .pt file
        ckpt = torch.hub.load_state_dict_from_url(
            checkpoint_url, progress=progress, check_hash=False, map_location="cpu"
        )
    model_sd = ckpt["model"]
    model_cfg = ckpt["cfg"]
    # Strip 'module.' prefixes left by DataParallel/DDP training so the keys
    # line up with a bare (unwrapped) model.
    new_sd = {k.replace("module.", ""): v for k, v in model_sd.items()}

    ############################################################################
    # 2) Load tokenizer
    ############################################################################
    torch.hub.download_url_to_file(tokenizer_url, _cached_file_path(tokenizer_url), progress=progress)
    texttok = RegexTokenizer()
    texttok.load(_cached_file_path(tokenizer_url))
    logging.info("Tokenizer loaded successfully.")

    ############################################################################
    # 3) Build MARS6 model
    ############################################################################
    text_vocab_size = len(texttok.vocab)
    # Speech vocab: three codebooks' worth of SNAC codes plus special tokens
    # (the *3 presumably reflects SNAC's codebook hierarchy — confirm upstream).
    n_speech_vocab = SNACTokenizerInfo.codebook_size * 3 + SNACTokenizerInfo.n_snac_special
    model = Mars6_Turbo(
        n_input_vocab=text_vocab_size,
        n_output_vocab=n_speech_vocab,
        emb_dim=model_cfg.get("dim", 512),
        n_layers=model_cfg.get("n_layers", 8),
        fast_n_layers=model_cfg.get("fast_n_layers", 4),
        n_langs=len(model_cfg.get("languages", ["en-us"]))
    )
    model.load_state_dict(new_sd)
    model = model.to(device=device, dtype=dtype)
    model.eval()
    logging.info("MARS6 model loaded successfully.")
    return model, texttok
def _load_safetensors_ckpt(url: str, progress: bool):
    """Download (if needed) and load a safetensors checkpoint from `url`.

    Returns a plain dict with:
      - "cfg": config parsed from the safetensors metadata (string values
        coerced to int/float where possible), or {} when there is no metadata.
      - "model": mapping of parameter name -> CPU tensor.
    """
    # CONSISTENCY FIX: reuse the shared _cached_file_path helper instead of
    # duplicating its hub-dir/checkpoints/makedirs logic inline.
    cached_file = _cached_file_path(url)
    if not os.path.exists(cached_file):
        # hash_prefix=None: no integrity check is performed on the download.
        torch.hub.download_url_to_file(url, cached_file, None, progress=progress)

    # Imported lazily so the .pt code path works without safetensors installed.
    from safetensors import safe_open

    ckpt = {"cfg": {}, "model": {}}
    with safe_open(cached_file, framework="pt", device="cpu") as f:
        meta = f.metadata()
        if meta is not None:
            ckpt["cfg"] = {k: _coerce_meta_value(v) for k, v in meta.items()}
        ckpt["model"] = {key: f.get_tensor(key) for key in f.keys()}
    return ckpt


def _coerce_meta_value(value: str):
    """Best-effort coercion of a safetensors metadata string to int or float,
    falling back to the raw string (same semantics as the original try-chain)."""
    for cast in (int, float):
        try:
            return cast(value)
        except ValueError:
            continue
    return value
def _cached_file_path(url: str) -> str:
    """
    Return the local path Torch Hub will use when downloading `url`,
    ensuring the hub checkpoints directory exists.
    """
    checkpoints_dir = os.path.join(torch.hub.get_dir(), "checkpoints")
    os.makedirs(checkpoints_dir, exist_ok=True)
    return os.path.join(checkpoints_dir, os.path.basename(url))