Skip to content

Commit ef768ad

Browse files
fmassasoumith
authored andcommitted
Fix for no tqdm (#770)
* Fixes for PyTorch version of tqdm * Flake * Flake fix
1 parent d4dec6b commit ef768ad

File tree

3 files changed

+33
-5
lines changed

3 files changed

+33
-5
lines changed

test/preprocess-bench.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -46,8 +46,10 @@
4646

4747
start_time = timer()
4848
batch_count = 20 * args.nThreads
49-
for _ in tqdm(range(batch_count)):
50-
batch = next(train_iter)
49+
with tqdm(total=batch_count) as pbar:
50+
for _ in tqdm(range(batch_count)):
51+
pbar.update(1)
52+
batch = next(train_iter)
5153
end_time = timer()
5254
print("Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch,"
5355
" {image:.2f} ms/image {rate:.0f} images/sec"

test/test_datasets_utils.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
import shutil
2+
import tempfile
3+
import torch
4+
import torchvision.datasets.utils as utils
5+
import unittest
6+
7+
8+
class Tester(unittest.TestCase):
9+
10+
def test_download_url(self):
11+
temp_dir = tempfile.mkdtemp()
12+
url = "http://github.com/pytorch/vision/archive/master.zip"
13+
utils.download_url(url, temp_dir)
14+
shutil.rmtree(temp_dir)
15+
16+
def test_download_url_retry_http(self):
17+
temp_dir = tempfile.mkdtemp()
18+
url = "https://github.com/pytorch/vision/archive/master.zip"
19+
utils.download_url(url, temp_dir)
20+
shutil.rmtree(temp_dir)
21+
22+
23+
if __name__ == '__main__':
24+
unittest.main()

torchvision/datasets/utils.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,9 @@
55
from torch.utils.model_zoo import tqdm
66

77

8-
def gen_bar_updater(pbar):
8+
def gen_bar_updater():
9+
pbar = tqdm(total=None)
10+
911
def bar_update(count, block_size, total_size):
1012
if pbar.total is None and total_size:
1113
pbar.total = total_size
@@ -70,7 +72,7 @@ def download_url(url, root, filename=None, md5=None):
7072
print('Downloading ' + url + ' to ' + fpath)
7173
urllib.request.urlretrieve(
7274
url, fpath,
73-
reporthook=gen_bar_updater(tqdm())
75+
reporthook=gen_bar_updater()
7476
)
7577
except OSError:
7678
if url[:5] == 'https':
@@ -79,7 +81,7 @@ def download_url(url, root, filename=None, md5=None):
7981
' Downloading ' + url + ' to ' + fpath)
8082
urllib.request.urlretrieve(
8183
url, fpath,
82-
reporthook=gen_bar_updater(tqdm())
84+
reporthook=gen_bar_updater()
8385
)
8486

8587

0 commit comments

Comments
 (0)