Skip to content

Commit

Permalink
Fix for no tqdm (#770)
Browse files Browse the repository at this point in the history
* Fixes for PyTorch version of tqdm

* Flake

* Flake fix
  • Loading branch information
fmassa authored and soumith committed Mar 1, 2019
1 parent d4dec6b commit ef768ad
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 5 deletions.
6 changes: 4 additions & 2 deletions test/preprocess-bench.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,8 +46,10 @@

start_time = timer()
batch_count = 20 * args.nThreads
for _ in tqdm(range(batch_count)):
batch = next(train_iter)
with tqdm(total=batch_count) as pbar:
for _ in tqdm(range(batch_count)):
pbar.update(1)
batch = next(train_iter)
end_time = timer()
print("Performance: {dataset:.0f} minutes/dataset, {batch:.1f} ms/batch,"
" {image:.2f} ms/image {rate:.0f} images/sec"
Expand Down
24 changes: 24 additions & 0 deletions test/test_datasets_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import shutil
import tempfile
import torch
import torchvision.datasets.utils as utils
import unittest


class Tester(unittest.TestCase):

def test_download_url(self):
temp_dir = tempfile.mkdtemp()
url = "http://github.com/pytorch/vision/archive/master.zip"
utils.download_url(url, temp_dir)
shutil.rmtree(temp_dir)

def test_download_url_retry_http(self):
temp_dir = tempfile.mkdtemp()
url = "https://github.com/pytorch/vision/archive/master.zip"
utils.download_url(url, temp_dir)
shutil.rmtree(temp_dir)


if __name__ == '__main__':
unittest.main()
8 changes: 5 additions & 3 deletions torchvision/datasets/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,9 @@
from torch.utils.model_zoo import tqdm


def gen_bar_updater(pbar):
def gen_bar_updater():
pbar = tqdm(total=None)

def bar_update(count, block_size, total_size):
if pbar.total is None and total_size:
pbar.total = total_size
Expand Down Expand Up @@ -70,7 +72,7 @@ def download_url(url, root, filename=None, md5=None):
print('Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater(tqdm())
reporthook=gen_bar_updater()
)
except OSError:
if url[:5] == 'https':
Expand All @@ -79,7 +81,7 @@ def download_url(url, root, filename=None, md5=None):
' Downloading ' + url + ' to ' + fpath)
urllib.request.urlretrieve(
url, fpath,
reporthook=gen_bar_updater(tqdm())
reporthook=gen_bar_updater()
)


Expand Down

0 comments on commit ef768ad

Please sign in to comment.