Skip to content

Commit ebc7e87

Browse files
author
pytorchbot
committed
2024-02-08 nightly release (81e2831)
1 parent 4197eab commit ebc7e87

File tree

9 files changed

+32
-68
lines changed

9 files changed

+32
-68
lines changed

.github/workflows/tests-schedule.yml

+1-1
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ jobs:
3636
run: pip install --no-build-isolation --editable .
3737

3838
- name: Install all optional dataset requirements
39-
run: pip install scipy pycocotools lmdb requests
39+
run: pip install scipy pycocotools lmdb gdown
4040

4141
- name: Install tests requirements
4242
run: pip install pytest

mypy.ini

+4
Original file line numberDiff line numberDiff line change
@@ -142,3 +142,7 @@ ignore_missing_imports = True
142142
[mypy-h5py.*]
143143

144144
ignore_missing_imports = True
145+
146+
[mypy-gdown.*]
147+
148+
ignore_missing_imports = True

setup.py

-1
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def write_version_file():
5959

6060
requirements = [
6161
"numpy",
62-
"requests",
6362
pytorch_dep,
6463
]
6564

torchvision/datasets/caltech.py

+4
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,10 @@ class Caltech101(VisionDataset):
3030
download (bool, optional): If true, downloads the dataset from the internet and
3131
puts it in root directory. If dataset is already downloaded, it is not
3232
downloaded again.
33+
34+
.. warning::
35+
36+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
3337
"""
3438

3539
def __init__(

torchvision/datasets/celeba.py

+4
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ class CelebA(VisionDataset):
3838
download (bool, optional): If true, downloads the dataset from the internet and
3939
puts it in root directory. If dataset is already downloaded, it is not
4040
downloaded again.
41+
42+
.. warning::
43+
44+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
4145
"""
4246

4347
base_folder = "celeba"

torchvision/datasets/pcam.py

+4
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,10 @@ class PCAM(VisionDataset):
2525
target_transform (callable, optional): A function/transform that takes in the target and transforms it.
2626
download (bool, optional): If True, downloads the dataset from the internet and puts it into ``root/pcam``. If
2727
dataset is already downloaded, it is not downloaded again.
28+
29+
.. warning::
30+
31+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
2832
"""
2933

3034
_FILES = {

torchvision/datasets/utils.py

+9-65
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
import bz2
2-
import contextlib
32
import gzip
43
import hashlib
5-
import itertools
64
import lzma
75
import os
86
import os.path
@@ -13,13 +11,11 @@
1311
import urllib
1412
import urllib.error
1513
import urllib.request
16-
import warnings
1714
import zipfile
1815
from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
1916
from urllib.parse import urlparse
2017

2118
import numpy as np
22-
import requests
2319
import torch
2420
from torch.utils.model_zoo import tqdm
2521

@@ -191,22 +187,6 @@ def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False
191187
return files
192188

193189

194-
def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
195-
content = response.iter_content(chunk_size)
196-
first_chunk = None
197-
# filter out keep-alive new chunks
198-
while not first_chunk:
199-
first_chunk = next(content)
200-
content = itertools.chain([first_chunk], content)
201-
202-
try:
203-
match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
204-
api_response = match["api_response"] if match is not None else None
205-
except UnicodeDecodeError:
206-
api_response = None
207-
return api_response, content
208-
209-
210190
def download_file_from_google_drive(
211191
file_id: str,
212192
root: Union[str, pathlib.Path],
@@ -221,7 +201,12 @@ def download_file_from_google_drive(
221201
filename (str, optional): Name to save the file under. If None, use the id of the file.
222202
md5 (str, optional): MD5 checksum of the download. If None, do not check
223203
"""
224-
# Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
204+
try:
205+
import gdown
206+
except ModuleNotFoundError:
207+
raise RuntimeError(
208+
"To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
209+
)
225210

226211
root = os.path.expanduser(root)
227212
if not filename:
@@ -234,51 +219,10 @@ def download_file_from_google_drive(
234219
print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
235220
return
236221

237-
url = "https://drive.google.com/uc"
238-
params = dict(id=file_id, export="download")
239-
with requests.Session() as session:
240-
response = session.get(url, params=params, stream=True)
222+
gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)
241223

242-
for key, value in response.cookies.items():
243-
if key.startswith("download_warning"):
244-
token = value
245-
break
246-
else:
247-
api_response, content = _extract_gdrive_api_response(response)
248-
token = "t" if api_response == "Virus scan warning" else None
249-
250-
if token is not None:
251-
response = session.get(url, params=dict(params, confirm=token), stream=True)
252-
api_response, content = _extract_gdrive_api_response(response)
253-
254-
if api_response == "Quota exceeded":
255-
raise RuntimeError(
256-
f"The daily quota of the file {filename} is exceeded and it "
257-
f"can't be downloaded. This is a limitation of Google Drive "
258-
f"and can only be overcome by trying again later."
259-
)
260-
261-
_save_response_content(content, fpath)
262-
263-
# In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
264-
if os.stat(fpath).st_size < 10 * 1024:
265-
with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
266-
text = fh.read()
267-
# Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
268-
if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
269-
warnings.warn(
270-
f"We detected some HTML elements in the downloaded file. "
271-
f"This most likely means that the download triggered an unhandled API response by GDrive. "
272-
f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
273-
f"the response:\n\n{text}"
274-
)
275-
276-
if md5 and not check_md5(fpath, md5):
277-
raise RuntimeError(
278-
f"The MD5 checksum of the download file {fpath} does not match the one on record."
279-
f"Please delete the file and try again. "
280-
f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
281-
)
224+
if not check_integrity(fpath, md5):
225+
raise RuntimeError("File not found or corrupted.")
282226

283227

284228
def _extract_tar(

torchvision/datasets/widerface.py

+4
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,10 @@ class WIDERFace(VisionDataset):
3434
puts it in root directory. If dataset is already downloaded, it is not
3535
downloaded again.
3636
37+
.. warning::
38+
39+
To download the dataset `gdown <https://github.com/wkentaro/gdown>`_ is required.
40+
3741
"""
3842

3943
BASE_FOLDER = "widerface"

torchvision/transforms/functional.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
import numpy as np
99
import torch
1010
from PIL import Image
11+
from PIL.Image import Image as PILImage
1112
from torch import Tensor
1213

1314
try:
@@ -123,7 +124,7 @@ def _is_numpy_image(img: Any) -> bool:
123124
return img.ndim in {2, 3}
124125

125126

126-
def to_tensor(pic) -> Tensor:
127+
def to_tensor(pic: Union[PILImage, np.ndarray]) -> Tensor:
127128
"""Convert a ``PIL Image`` or ``numpy.ndarray`` to tensor.
128129
This function does not support torchscript.
129130

0 commit comments

Comments
 (0)