1
1
import bz2
2
- import contextlib
3
2
import gzip
4
3
import hashlib
5
- import itertools
6
4
import lzma
7
5
import os
8
6
import os .path
13
11
import urllib
14
12
import urllib .error
15
13
import urllib .request
16
- import warnings
17
14
import zipfile
18
15
from typing import Any , Callable , Dict , IO , Iterable , Iterator , List , Optional , Tuple , TypeVar
19
16
from urllib .parse import urlparse
20
17
21
18
import numpy as np
22
- import requests
23
19
import torch
24
20
from torch .utils .model_zoo import tqdm
25
21
@@ -187,22 +183,6 @@ def list_files(root: str, suffix: str, prefix: bool = False) -> List[str]:
187
183
return files
188
184
189
185
190
- def _extract_gdrive_api_response (response , chunk_size : int = 32 * 1024 ) -> Tuple [bytes , Iterator [bytes ]]:
191
- content = response .iter_content (chunk_size )
192
- first_chunk = None
193
- # filter out keep-alive new chunks
194
- while not first_chunk :
195
- first_chunk = next (content )
196
- content = itertools .chain ([first_chunk ], content )
197
-
198
- try :
199
- match = re .search ("<title>Google Drive - (?P<api_response>.+?)</title>" , first_chunk .decode ())
200
- api_response = match ["api_response" ] if match is not None else None
201
- except UnicodeDecodeError :
202
- api_response = None
203
- return api_response , content
204
-
205
-
206
186
def download_file_from_google_drive (file_id : str , root : str , filename : Optional [str ] = None , md5 : Optional [str ] = None ):
207
187
"""Download a Google Drive file from and place it in root.
208
188
@@ -212,7 +192,12 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
212
192
filename (str, optional): Name to save the file under. If None, use the id of the file.
213
193
md5 (str, optional): MD5 checksum of the download. If None, do not check
214
194
"""
215
- # Based on https://stackoverflow.com/questions/38511444/python-download-files-from-google-drive-using-url
195
+ try :
196
+ import gdown
197
+ except ModuleNotFoundError :
198
+ raise RuntimeError (
199
+ "To download files from GDrive, 'gdown' is required. You can install it with 'pip install gdown'."
200
+ )
216
201
217
202
root = os .path .expanduser (root )
218
203
if not filename :
@@ -225,51 +210,10 @@ def download_file_from_google_drive(file_id: str, root: str, filename: Optional[
225
210
print (f"Using downloaded { 'and verified ' if md5 else '' } file: { fpath } " )
226
211
return
227
212
228
- url = "https://drive.google.com/uc"
229
- params = dict (id = file_id , export = "download" )
230
- with requests .Session () as session :
231
- response = session .get (url , params = params , stream = True )
213
+ gdown .download (id = file_id , output = fpath , quiet = False , user_agent = USER_AGENT )
232
214
233
- for key , value in response .cookies .items ():
234
- if key .startswith ("download_warning" ):
235
- token = value
236
- break
237
- else :
238
- api_response , content = _extract_gdrive_api_response (response )
239
- token = "t" if api_response == "Virus scan warning" else None
240
-
241
- if token is not None :
242
- response = session .get (url , params = dict (params , confirm = token ), stream = True )
243
- api_response , content = _extract_gdrive_api_response (response )
244
-
245
- if api_response == "Quota exceeded" :
246
- raise RuntimeError (
247
- f"The daily quota of the file { filename } is exceeded and it "
248
- f"can't be downloaded. This is a limitation of Google Drive "
249
- f"and can only be overcome by trying again later."
250
- )
251
-
252
- _save_response_content (content , fpath )
253
-
254
- # In case we deal with an unhandled GDrive API response, the file should be smaller than 10kB and contain only text
255
- if os .stat (fpath ).st_size < 10 * 1024 :
256
- with contextlib .suppress (UnicodeDecodeError ), open (fpath ) as fh :
257
- text = fh .read ()
258
- # Regular expression to detect HTML. Copied from https://stackoverflow.com/a/70585604
259
- if re .search (r"</?\s*[a-z-][^>]*\s*>|(&(?:[\w\d]+|#\d+|#x[a-f\d]+);)" , text ):
260
- warnings .warn (
261
- f"We detected some HTML elements in the downloaded file. "
262
- f"This most likely means that the download triggered an unhandled API response by GDrive. "
263
- f"Please report this to torchvision at https://github.com/pytorch/vision/issues including "
264
- f"the response:\n \n { text } "
265
- )
266
-
267
- if md5 and not check_md5 (fpath , md5 ):
268
- raise RuntimeError (
269
- f"The MD5 checksum of the download file { fpath } does not match the one on record."
270
- f"Please delete the file and try again. "
271
- f"If the issue persists, please report this to torchvision at https://github.com/pytorch/vision/issues."
272
- )
215
+ if not check_integrity (fpath , md5 ):
216
+ raise RuntimeError ("File not found or corrupted." )
273
217
274
218
275
219
def _extract_tar (from_path : str , to_path : str , compression : Optional [str ]) -> None :
0 commit comments