Skip to content

Commit a0e8e6c

Browse files
NicolasHugpmeier
andauthored
[Cherry-pick for 0.17.1] add gdown as optional requirement for dataset GDrive download (#8264)
Co-authored-by: Philip Meier <[email protected]>
1 parent b2383d4 commit a0e8e6c

File tree

8 files changed

+30
-67
lines changed

8 files changed

+30
-67
lines changed

.github/workflows/tests-schedule.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
run: pip install --no-build-isolation --editable .
3737

3838
- name: Install all optional dataset requirements
39-
run: pip install scipy pycocotools lmdb requests
39+
run: pip install scipy pycocotools lmdb gdown
4040

4141
- name: Install tests requirements
4242
run: pip install pytest

mypy.ini

+4
Original file line numberDiff line numberDiff line change
@@ -110,3 +110,7 @@ ignore_missing_imports = True
110110
[mypy-h5py.*]
111111

112112
ignore_missing_imports = True
113+
114+
[mypy-gdown.*]
115+
116+
ignore_missing_imports = True

setup.py

-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def write_version_file():
5959

6060
requirements = [
6161
"numpy",
62-
"requests",
6362
pytorch_dep,
6463
]
6564

torchvision/datasets/caltech.py

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class Caltech101(VisionDataset):
3030
download (bool, optional): If true, downloads the dataset from the internet and
3131
puts it in root directory. If dataset is already downloaded, it is not
3232
downloaded again.
33+
34+
.. warning::
35+
36+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
3337
"""
3438

3539
def __init__(

torchvision/datasets/celeba.py

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class CelebA(VisionDataset):
3838
download (bool, optional): If true, downloads the dataset from the internet and
3939
puts it in root directory. If dataset is already downloaded, it is not
4040
downloaded again.
41+
42+
.. warning::
43+
44+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
4145
"""
4246

4347
base_folder = "celeba"

torchvision/datasets/pcam.py

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class PCAM(VisionDataset):
2525
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
2626
download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
2727
dataset is already downloaded, it is not downloaded again.
28+
29+
.. warning::
30+
31+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
2832
"""
2933

3034
_FILES = {

torchvision/datasets/utils.py

+9-65
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import bz2
2-
import contextlib
32
import gzip
43
import hashlib
5-
import itertools
64
import lzma
75
import os
86
import os.path
@@ -13,13 +11,11 @@
1311
import urllib
1412
import urllib.error
1513
import urllib.request
16-
import warnings
1714
import zipfile
1815
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar
1916
from urllib.parse import urlparse
2017

2118
import numpy as np
22-
import requests
2319
import torch
2420
from torch.utils.model_zoo import tqdm
2521

@@ -187,22 +183,6 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
187183
return files
188184

189185

190-
def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
191-
content = response.iter_content(chunk_size)
192-
first_chunk = None
193-
# filter out keep-alive new chunks
194-
while not first_chunk:
195-
first_chunk = next(content)
196-
content = itertools.chain([first_chunk], content)
197-
198-
try:
199-
match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
200-
api_response = match["api_response"] if match is not None else None
201-
except UnicodeDecodeError:
202-
api_response = None
203-
return api_response, content
204-
205-
206186
def download_file_from_google_drive(file_id: str, root: str, filename: Optional[str] = None, md5: Optional[str] = None):
207187
"""Download a Google Drive file from and place it in root.
208188
@@ -212,7 +192,12 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
212192
filename (str, optional): Name to save the file under. If None, use the id of the file.
213193
md5 (str, optional): MD5 checksum of the download. If None, do not check
214194
"""
215-
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
195+
try:
196+
import gdown
197+
except ModuleNotFoundError:
198+
raise RuntimeError(
199+
"To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
200+
)
216201

217202
root = os.path.expanduser(root)
218203
if not filename:
@@ -225,51 +210,10 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
225210
print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
226211
return
227212

228-
url = "https://drive.google.com/uc"
229-
params = dict(id=file_id, export="download")
230-
with requests.Session() as session:
231-
response = session.get(url, params=params, stream=True)
213+
gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)
232214

233-
for key, value in response.cookies.items():
234-
if key.startswith("download_warning"):
235-
token = value
236-
break
237-
else:
238-
api_response, content = _extract_gdrive_api_response(response)
239-
token = "t" if api_response == "Virus scan warning" else None
240-
241-
if token is not None:
242-
response = session.get(url, params=dict(params, confirm=token), stream=True)
243-
api_response, content = _extract_gdrive_api_response(response)
244-
245-
if api_response == "Quota exceeded":
246-
raise RuntimeError(
247-
f"The daily quota of the file {filename} is exceeded and it "
248-
f"can't be downloaded. This is a limitation of Google Drive "
249-
f"and can only be overcome by trying again later."
250-
)
251-
252-
_save_response_content(content, fpath)
253-
254-
# In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
255-
if os.stat(fpath).st_size < 10 * 1024:
256-
with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
257-
text = fh.read()
258-
# Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
259-
if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
260-
warnings.warn(
261-
f"We detected some HTML elements in the downloaded file. "
262-
f"This most likely means that the download triggered an unhandled API response by GDrive. "
263-
f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
264-
f"the response:\n\n{text}"
265-
)
266-
267-
if md5 and not check_md5(fpath, md5):
268-
raise RuntimeError(
269-
f"The MD5 checksum of the download file {fpath} does not match the one on record."
270-
f"Please delete the file and try again. "
271-
f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
272-
)
215+
if not check_integrity(fpath, md5):
216+
raise RuntimeError("File not found or corrupted.")
273217

274218

275219
def _extract_tar(from_path: str, to_path: str, compression: Optional[str]) -> None:

torchvision/datasets/widerface.py

+4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class WIDERFace(VisionDataset):
3434
puts it in root directory. If dataset is already downloaded, it is not
3535
downloaded again.
3636
37+
.. warning::
38+
39+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
40+
3741
"""
3842

3943
BASE_FOLDER = "widerface"

0 commit comments

Comments
 (0)