|
1 | | -import urllib.request |
2 | 1 | from pathlib import Path |
3 | | -from threading import Thread |
4 | | -from urllib.error import HTTPError |
5 | 2 |
|
6 | | -from tqdm import tqdm |
| 3 | +from huggingface_hub import hf_hub_download |
7 | 4 |
|
| 5 | +HUGGINGFACE_REPO = "CorentinJ/SV2TTS" |
8 | 6 |
|
9 | 7 | default_models = { |
10 | | - "encoder": ("https://drive.google.com/uc?export=download&id=1q8mEGwCkFy23KZsinbuvdKAQLqNKbYf1", 17090379), |
11 | | - "synthesizer": ("https://drive.google.com/u/0/uc?id=1EqFMIbvxffxtjiVrtykroF6_mUh-5Z3s&export=download&confirm=t", 370554559), |
12 | | - "vocoder": ("https://drive.google.com/uc?export=download&id=1cf2NO6FtI0jDuy8AV3Xgn6leO6dHjIgu", 53845290), |
| 8 | + "encoder": 17090379, |
| 9 | + "synthesizer": 370554559, |
| 10 | + "vocoder": 53845290, |
13 | 11 | } |
14 | 12 |
|
15 | 13 |
|
16 | | -class DownloadProgressBar(tqdm): |
17 | | - def update_to(self, b=1, bsize=1, tsize=None): |
18 | | - if tsize is not None: |
19 | | - self.total = tsize |
20 | | - self.update(b * bsize - self.n) |
| 14 | +def _download_model(model_name: str, target_dir: Path): |
| 15 | + hf_hub_download( |
| 16 | + repo_id=HUGGINGFACE_REPO, |
| 17 | + revision="main", |
| 18 | + filename=f"{model_name}.pt", |
| 19 | + local_dir=str(target_dir), |
| 20 | + local_dir_use_symlinks=False, |
| 21 | + ) |
21 | 22 |
|
22 | 23 |
|
23 | | -def download(url: str, target: Path, bar_pos=0): |
24 | | - # Ensure the directory exists |
25 | | - target.parent.mkdir(exist_ok=True, parents=True) |
26 | | - |
27 | | - desc = f"Downloading {target.name}" |
28 | | - with DownloadProgressBar(unit="B", unit_scale=True, miniters=1, desc=desc, position=bar_pos, leave=False) as t: |
29 | | - try: |
30 | | - urllib.request.urlretrieve(url, filename=target, reporthook=t.update_to) |
31 | | - except HTTPError: |
32 | | - return |
| 24 | +def ensure_default_models(models_dir: Path): |
| 25 | + target_dir = models_dir / "default" |
| 26 | + target_dir.mkdir(parents=True, exist_ok=True) |
33 | 27 |
|
| 28 | + for model_name, expected_size in default_models.items(): |
| 29 | + target_path = target_dir / f"{model_name}.pt" |
34 | 30 |
|
35 | | -def ensure_default_models(models_dir: Path): |
36 | | - # Define download tasks |
37 | | - jobs = [] |
38 | | - for model_name, (url, size) in default_models.items(): |
39 | | - target_path = models_dir / "default" / f"{model_name}.pt" |
40 | 31 | if target_path.exists(): |
41 | | - if target_path.stat().st_size != size: |
42 | | - print(f"File {target_path} is not of expected size, redownloading...") |
43 | | - else: |
| 32 | + if target_path.stat().st_size == expected_size: |
44 | 33 | continue |
| 34 | + print(f"File {target_path} is not of expected size, redownloading...") |
45 | 35 |
|
46 | | - thread = Thread(target=download, args=(url, target_path, len(jobs))) |
47 | | - thread.start() |
48 | | - jobs.append((thread, target_path, size)) |
49 | | - |
50 | | - # Run and join threads |
51 | | - for thread, target_path, size in jobs: |
52 | | - thread.join() |
| 36 | + _download_model(model_name, target_dir) |
53 | 37 |
|
54 | | - assert target_path.exists() and target_path.stat().st_size == size, \ |
55 | | - f"Download for {target_path.name} failed. You may download models manually instead.\n" \ |
56 | | - f"https://drive.google.com/drive/folders/1fU6umc5uQAVR2udZdHX-lDgXYzTyqG_j" |
| 38 | + assert target_path.exists() and target_path.stat().st_size == expected_size, ( |
| 39 | + f"Download for {target_path.name} failed. You may download models manually instead.\n" |
| 40 | + f"https://huggingface.co/{HUGGINGFACE_REPO}" |
| 41 | + ) |
0 commit comments