 import bz2
-import contextlib
 import gzip
 import hashlib
-import itertools
 import lzma
 import os
 import os.path
 import urllib
 import urllib.error
 import urllib.request
-import warnings
 import zipfile
 from typing import Any, Callable, Dict, IO, Iterable, Iterator, List, Optional, Tuple, TypeVar, Union
 from urllib.parse import urlparse
 
 import numpy as np
-import requests
 import torch
 from torch.utils.model_zoo import tqdm
 
@@ -191,22 +187,6 @@ def list_files(root: Union[str, pathlib.Path], suffix: str, prefix: bool = False
     return files
 
 
-def _extract_gdrive_api_response(response, chunk_size: int = 32 * 1024) -> Tuple[bytes, Iterator[bytes]]:
-    content = response.iter_content(chunk_size)
-    first_chunk = None
-    # filter out keep-alive new chunks
-    while not first_chunk:
-        first_chunk = next(content)
-    content = itertools.chain([first_chunk], content)
-
-    try:
-        match = re.search("<title>Google Drive - (?P<api_response>.+?)</title>", first_chunk.decode())
-        api_response = match["api_response"] if match is not None else None
-    except UnicodeDecodeError:
-        api_response = None
-    return api_response, content
-
-
 def download_file_from_google_drive(
     file_id: str,
     root: Union[str, pathlib.Path],
@@ -221,7 +201,12 @@ def download_file_from_google_drive(
         filename (str, optional): Name to save the file under. If None, use the id of the file.
         md5 (str, optional): MD5 checksum of the download. If None, do not check
     """
-    # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
+    try:
+        import gdown
+    except ModuleNotFoundError:
+        raise RuntimeError(
+            "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
+        )
 
     root = os.path.expanduser(root)
     if not filename:
@@ -234,51 +219,10 @@ def download_file_from_google_drive(
         print(f"Using downloaded {'and verified ' if md5 else ''}file: {fpath}")
         return
 
-    url = "https://drive.google.com/uc"
-    params = dict(id=file_id, export="download")
-    with requests.Session() as session:
-        response = session.get(url, params=params, stream=True)
+    gdown.download(id=file_id, output=fpath, quiet=False, user_agent=USER_AGENT)
 
-        for key, value in response.cookies.items():
-            if key.startswith("download_warning"):
-                token = value
-                break
-        else:
-            api_response, content = _extract_gdrive_api_response(response)
-            token = "t" if api_response == "Virus scan warning" else None
-
-        if token is not None:
-            response = session.get(url, params=dict(params, confirm=token), stream=True)
-            api_response, content = _extract_gdrive_api_response(response)
-
-        if api_response == "Quota exceeded":
-            raise RuntimeError(
-                f"The daily quota of the file {filename} is exceeded and it "
-                f"can't be downloaded. This is a limitation of Google Drive "
-                f"and can only be overcome by trying again later."
-            )
-
-        _save_response_content(content, fpath)
-
-    # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
-    if os.stat(fpath).st_size < 10 * 1024:
-        with contextlib.suppress(UnicodeDecodeError), open(fpath) as fh:
-            text = fh.read()
-            # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
-            if re.search(r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)", text):
-                warnings.warn(
-                    f"We detected some HTML elements in the downloaded file. "
-                    f"This most likely means that the download triggered an unhandled API response by GDrive. "
-                    f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
-                    f"the response:\n\n{text}"
-                )
-
-    if md5 and not check_md5(fpath, md5):
-        raise RuntimeError(
-            f"The MD5 checksum of the download file {fpath} does not match the one on record."
-            f"Please delete the file and try again. "
-            f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
-        )
+    if not check_integrity(fpath, md5):
+        raise RuntimeError("File not found or corrupted.")
 
 
 def _extract_tar(
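
For reference, a minimal usage sketch of the helper as it reads after this change. The file id, directory, and filename below are placeholders, not real artifacts, and gdown must be installed (pip install gdown); otherwise the helper raises the RuntimeError added above.

from torchvision.datasets.utils import download_file_from_google_drive

# Placeholder values for illustration only; substitute a real Google Drive file id
# and, if desired, the expected MD5 checksum of the file.
download_file_from_google_drive(
    file_id="EXAMPLE_FILE_ID",   # hypothetical Google Drive file id
    root="~/datasets",           # expanded via os.path.expanduser before downloading
    filename="example.tar.gz",   # saved name; defaults to the file id if omitted
    md5=None,                    # when set, the download is verified via check_integrity
)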