downloader.py: 108 changes (56 additions, 52 deletions)
@@ -12,16 +12,18 @@

from requests.exceptions import ConnectionError, ReadTimeout, TooManyRedirects, MissingSchema, InvalidURL

from functools import partial

parser = argparse.ArgumentParser(description='ImageNet image scraper')
parser.add_argument('-scrape_only_flickr', default=True, type=lambda x: (str(x).lower() == 'true'))
parser.add_argument('-number_of_classes', default = 10, type=int)
parser.add_argument('-images_per_class', default = 10, type=int)
parser.add_argument('-data_root', default='' , type=str)
parser.add_argument('-use_class_list', default=False,type=lambda x: (str(x).lower() == 'true'))
parser.add_argument('-number_of_classes', default=10, type=int)
parser.add_argument('-images_per_class', default=10, type=int)
parser.add_argument('-data_root', default='', type=str)
parser.add_argument('-use_class_list', default=False, type=lambda x: (str(x).lower() == 'true'))
parser.add_argument('-class_list', default=[], nargs='*')
parser.add_argument('-debug', default=False,type=lambda x: (str(x).lower() == 'true'))
parser.add_argument('-debug', default=False, type=lambda x: (str(x).lower() == 'true'))

parser.add_argument('-multiprocessing_workers', default = 8, type=int)
parser.add_argument('-multiprocessing_workers', default=8, type=int)

args, args_other = parser.parse_known_args()
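A note on the `type=lambda x: (str(x).lower() == 'true')` idiom above: it is the usual workaround for argparse's boolean handling, since `type=bool` treats any non-empty string as truthy, so `-debug False` would still enable debugging. A minimal sketch of the difference:

import argparse

p = argparse.ArgumentParser()
# Naive version: bool('False') is True, so this flag can never be turned off.
p.add_argument('-naive', default=False, type=bool)
# Pattern used in downloader.py: compare the lowered string to 'true'.
p.add_argument('-debug', default=False, type=lambda x: str(x).lower() == 'true')

args = p.parse_args(['-naive', 'False', '-debug', 'False'])
print(args.naive)  # True  (surprising)
print(args.debug)  # False (intended)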

@@ -36,7 +38,6 @@
    logging.error(f'folder {args.data_root} does not exist! please provide existing folder in -data_root arg!')
    exit()


IMAGENET_API_WNID_TO_URLS = lambda wnid: f'http://www.image-net.org/api/imagenet.synset.geturls?wnid={wnid}'

current_folder = os.path.dirname(os.path.realpath(__file__))
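IMAGENET_API_WNID_TO_URLS assigns a lambda to a module-level name; behavior is identical to a plain function, though PEP 8 prefers def for named callables, mainly for clearer tracebacks. An equivalent form, for reference:

# Equivalent def form of the URL builder (same behavior, friendlier tracebacks):
def imagenet_api_wnid_to_urls(wnid):
    return f'http://www.image-net.org/api/imagenet.synset.geturls?wnid={wnid}'

# e.g. imagenet_api_wnid_to_urls('n01440764') yields the geturls endpoint
# for the 'tench' synset.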
@@ -52,11 +53,11 @@
classes_to_scrape = []

if args.use_class_list == True:
        for item in args.class_list:
            classes_to_scrape.append(item)
            if item not in class_info_dict:
                logging.error(f'Class {item} not found in ImageNet')
                exit()
    for item in args.class_list:
        classes_to_scrape.append(item)
        if item not in class_info_dict:
            logging.error(f'Class {item} not found in ImageNet')
            exit()
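Note that the re-indented block preserves the original ordering: each wnid is appended to classes_to_scrape before it is validated against class_info_dict. That only works because a failed lookup exits the whole program. A slightly tighter variant (an editorial sketch using the script's own names, not part of this PR) validates first and drops the `== True` comparison:

if args.use_class_list:
    for item in args.class_list:
        if item not in class_info_dict:
            logging.error(f'Class {item} not found in ImageNet')
            exit()
        classes_to_scrape.append(item)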

elif args.use_class_list == False:
    potential_class_pool = []
@@ -70,24 +71,23 @@
                potential_class_pool.append(key)

    if (len(potential_class_pool) < args.number_of_classes):
        logging.error(f"With {args.images_per_class} images per class there are {len(potential_class_pool)} to choose from.")
        logging.error(
            f"With {args.images_per_class} images per class there are {len(potential_class_pool)} to choose from.")
        logging.error(f"Decrease number of classes or decrease images per class.")
        exit()

    picked_classes_idxes = np.random.choice(len(potential_class_pool), args.number_of_classes, replace = False)
    picked_classes_idxes = np.random.choice(len(potential_class_pool), args.number_of_classes, replace=False)

    for idx in picked_classes_idxes:
        classes_to_scrape.append(potential_class_pool[idx])


print("Picked the following classes:")
print([ class_info_dict[class_wnid]['class_name'] for class_wnid in classes_to_scrape ])
print([class_info_dict[class_wnid]['class_name'] for class_wnid in classes_to_scrape])

imagenet_images_folder = os.path.join(args.data_root, 'imagenet_images')
if not os.path.isdir(imagenet_images_folder):
    os.mkdir(imagenet_images_folder)


scraping_stats = dict(
    all=dict(
        tried=0,
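For readers unfamiliar with np.random.choice: with replace=False it samples distinct indices, which is what guarantees no class is picked twice. A standalone sketch with a hypothetical four-class pool:

import numpy as np

pool = ['n01440764', 'n01443537', 'n01484850', 'n01491361']  # hypothetical wnid pool
idxs = np.random.choice(len(pool), 2, replace=False)  # two distinct indices
print([pool[i] for i in idxs])  # e.g. ['n01484850', 'n01440764']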
@@ -106,33 +106,35 @@
    )
)


def add_debug_csv_row(row):
    with open('stats.csv', "a") as csv_f:
        csv_writer = csv.writer(csv_f, delimiter=",")
        csv_writer.writerow(row)


class MultiStats():
    def __init__(self):

        self.lock = Lock()

        self.stats = dict(
            all=dict(
                tried=Value('d', 0),
                success=Value('d',0),
                time_spent=Value('d',0),
                success=Value('d', 0),
                time_spent=Value('d', 0),
            ),
            is_flickr=dict(
                tried=Value('d', 0),
                success=Value('d',0),
                time_spent=Value('d',0),
                success=Value('d', 0),
                time_spent=Value('d', 0),
            ),
            not_flickr=dict(
                tried=Value('d', 0),
                success=Value('d', 0),
                time_spent=Value('d', 0),
            )
        )

    def inc(self, cls, stat, val):
        with self.lock:
            self.stats[cls][stat].value += val
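MultiStats combines multiprocessing.Value (shared memory visible to all workers) with one Lock that serializes updates, since `+= ` on a Value is a non-atomic read-modify-write. A toy version of the same pattern, assuming a fork start method so the module-level globals are inherited by the pool workers, as downloader.py itself relies on:

from multiprocessing import Lock, Pool, Value

counter = Value('d', 0)  # 'd' = double, matching downloader.py's counters
lock = Lock()

def work(_):
    with lock:  # serialize the read-modify-write
        counter.value += 1

if __name__ == '__main__':
    with Pool(4) as p:
        p.map(work, range(100))
    print(counter.value)  # 100.0 every time; without the lock, sometimes less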
@@ -142,8 +144,8 @@ def get(self, cls, stat):
            ret = self.stats[cls][stat].value
        return ret

multi_stats = MultiStats()

multi_stats = MultiStats()

if args.debug:
    row = [
@@ -159,6 +161,7 @@ def get(self, cls, stat):
    ]
    add_debug_csv_row(row)


def add_stats_to_debug_csv():
    row = [
        multi_stats.get('all', 'tried'),
@@ -173,8 +176,8 @@ def add_stats_to_debug_csv():
    ]
    add_debug_csv_row(row)

def print_stats(cls, print_func):

def print_stats(cls, print_func):
    actual_all_time_spent = time.time() - scraping_t_start.value
    processes_all_time_spent = multi_stats.get('all', 'time_spent')

@@ -183,17 +186,18 @@ def print_stats(cls, print_func):
    else:
        actual_processes_ratio = actual_all_time_spent / processes_all_time_spent

    #print(f"actual all time: {actual_all_time_spent} proc all time {processes_all_time_spent}")
    # print(f"actual all time: {actual_all_time_spent} proc all time {processes_all_time_spent}")

    print_func(f'STATS For class {cls}:')
    print_func(f' tried {multi_stats.get(cls, "tried")} urls with'
               f' {multi_stats.get(cls, "success")} successes')

    if multi_stats.get(cls, "tried") > 0:
        print_func(f'{100.0 * multi_stats.get(cls, "success")/multi_stats.get(cls, "tried")}% success rate for {cls} urls ')
        print_func(
            f'{100.0 * multi_stats.get(cls, "success") / multi_stats.get(cls, "tried")}% success rate for {cls} urls ')
    if multi_stats.get(cls, "success") > 0:
        print_func(f'{multi_stats.get(cls,"time_spent") * actual_processes_ratio / multi_stats.get(cls,"success")} seconds spent per {cls} successful image download')

        print_func(
            f'{multi_stats.get(cls, "time_spent") * actual_processes_ratio / multi_stats.get(cls, "success")} seconds spent per {cls} successful image download')


lock = Lock()
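The actual_processes_ratio arithmetic above corrects for parallelism: each worker accumulates its own time, so the summed time_spent overstates wall-clock time by roughly the worker count, and rescaling by wall-clock over summed time recovers real seconds per successful download. In numbers (illustrative):

# 8 workers each spend ~60s downloading, so summed process time is ~480s,
# while the wall clock only advanced ~60s.
actual_all_time_spent = 60.0
processes_all_time_spent = 480.0
ratio = actual_all_time_spent / processes_all_time_spent  # 0.125
# 96 successes over 480 summed seconds looks like 5s each,
# but is really 0.625 wall-clock seconds each once rescaled:
print(480.0 * ratio / 96)  # 0.625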
@@ -202,16 +206,15 @@ def print_stats(cls, print_func):
class_folder = ''
class_images = Value('d', 0)

def get_image(img_url):

    #print(f'Processing {img_url}')
def get_image(class_folder, img_url):
    # print(f'Processing {img_url}')

    #time.sleep(3)
    # time.sleep(3)

    if len(img_url) <= 1:
        return


    cls_imgs = 0
    with lock:
        cls_imgs = class_images.value
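This hunk carries the substantive change of the PR: get_image now takes class_folder as its first parameter instead of reading a global, so the main loop can bind it per class with functools.partial (the import added at the top of the file). Pool.map passes exactly one argument per item, which is why the binding is needed. The mechanism in isolation, with a hypothetical stub in place of the real download logic:

from functools import partial
from multiprocessing import Pool

def get_image(class_folder, img_url):
    # Stub: the real function downloads img_url and saves it under class_folder.
    print(f'{img_url} -> {class_folder}')

if __name__ == '__main__':
    urls = ['http://example.com/a.jpg', 'http://example.com/b.jpg']  # hypothetical
    bound = partial(get_image, 'imagenet_images/tench')  # fixes the first argument
    with Pool(2) as p:
        p.map(bound, urls)  # each call becomes get_image('imagenet_images/tench', url)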
@@ -237,11 +240,11 @@ def finish(status):
        multi_stats.inc(cls, 'time_spent', t_spent)
        multi_stats.inc('all', 'time_spent', t_spent)

        multi_stats.inc(cls,'tried', 1)
        multi_stats.inc(cls, 'tried', 1)
        multi_stats.inc('all', 'tried', 1)

        if status == 'success':
            multi_stats.inc(cls,'success', 1)
            multi_stats.inc(cls, 'success', 1)
            multi_stats.inc('all', 'success', 1)

        elif status == 'failure':
@@ -251,7 +254,6 @@ def finish(status):
            exit()
        return


    with lock:
        url_tries.value += 1
        if url_tries.value % 250 == 0:
@@ -263,7 +265,7 @@ def finish(status):
            add_stats_to_debug_csv()

    try:
        img_resp = requests.get(img_url, timeout = 1)
        img_resp = requests.get(img_url, timeout=1)
    except ConnectionError:
        logging.debug(f"Connection Error for url {img_url}")
        return finish('failure')
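The timeout=1 keyword bounds how long each fetch may take; without it, requests.get can hang indefinitely on a dead host and stall a pool worker. The narrow except clauses here (and the others imported at the top of the file) turn per-URL network failures into a counted 'failure' rather than a crash. The pattern on its own, with a hypothetical fetch helper:

import requests
from requests.exceptions import ConnectionError, ReadTimeout

def fetch(img_url):
    """Hypothetical helper: return the response, or None on a network failure."""
    try:
        # timeout is in seconds and applies to the connect and read phases
        return requests.get(img_url, timeout=1)
    except (ConnectionError, ReadTimeout):
        return None  # caller counts this as a failed url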
@@ -314,26 +316,28 @@ def finish(status):
    return finish('success')


for class_wnid in classes_to_scrape:
if __name__ == '__main__':

    class_name = class_info_dict[class_wnid]["class_name"]
    print(f'Scraping images for class \"{class_name}\"')
    url_urls = IMAGENET_API_WNID_TO_URLS(class_wnid)
    for class_wnid in classes_to_scrape:

    time.sleep(0.05)
    resp = requests.get(url_urls)
        class_name = class_info_dict[class_wnid]["class_name"]
        print(f'Scraping images for class \"{class_name}\"')
        url_urls = IMAGENET_API_WNID_TO_URLS(class_wnid)

    class_folder = os.path.join(imagenet_images_folder, class_name)
    if not os.path.exists(class_folder):
        os.mkdir(class_folder)
        time.sleep(0.05)
        resp = requests.get(url_urls)

    class_images.value = 0
        class_folder = os.path.join(imagenet_images_folder, class_name)
        if not os.path.exists(class_folder):
            os.mkdir(class_folder)

    urls = [url.decode('utf-8') for url in resp.content.splitlines()]
        class_images.value = 0

    #for url in urls:
    # get_image(url)
        urls = [url.decode('utf-8') for url in resp.content.splitlines()]

    print(f"Multiprocessing workers: {args.multiprocessing_workers}")
    with Pool(processes=args.multiprocessing_workers) as p:
        p.map(get_image,urls)
        # for url in urls:
        # get_image(url)
        part = partial(get_image, class_folder)
        print(f"Multiprocessing workers: {args.multiprocessing_workers}")
        with Pool(processes=args.multiprocessing_workers) as p:
            p.map(part, urls)
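Wrapping the scraping loop in if __name__ == '__main__': matters for more than style here: on platforms where multiprocessing starts workers by re-importing the module (Windows, and macOS on recent Python versions), an unguarded Pool at module level re-executes the top-level scraping code in every child. With the guard in place, a typical invocation might look like this (flag values are illustrative; all flags exist in the parser above):

python downloader.py -data_root ./data -number_of_classes 5 -images_per_class 20 -multiprocessing_workers 8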