Description
The function `one_hot_from_names` raises an `AssertionError` when it is given a class name that is not one of the original ImageNet classes and for which no candidate synset (`possible_synsets`) can be found either.
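This happens because `one_hot_from_names` in `utils.py` does not update `batch_size` after converting the names to their class indices. A name that cannot be mapped contributes no index, so the list passed to `one_hot_from_int` is shorter than `batch_size` (empty, in the case of `['cake']`), and the check `assert batch_size == len(int_or_list)` fails.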
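To make the mismatch concrete, here is a minimal pure-Python sketch of the check in `one_hot_from_int`. `one_hot_from_int_sketch` is a hypothetical name, and lists of floats stand in for the NumPy array. Only the repeat (traceback line 164), the assertion (line 166) and the allocation (line 168) are visible in the traceback; the `isinstance` handling and the `len(...) == 1 and batch_size > 1` guard before the repeat are assumptions.

```python
NUM_CLASSES = 1000  # number of ImageNet classes


def one_hot_from_int_sketch(int_or_list, batch_size=1):
    """Hypothetical stand-in for one_hot_from_int (lists instead of a NumPy array)."""
    if isinstance(int_or_list, int):  # assumed: a bare int is wrapped in a list
        int_or_list = [int_or_list]
    # Assumed guard: a single index is repeated to fill the batch (traceback line 164).
    if len(int_or_list) == 1 and batch_size > 1:
        int_or_list = [int_or_list[0]] * batch_size
    # The check from traceback line 166: one index per batch entry.
    assert batch_size == len(int_or_list)
    rows = [[0.0] * NUM_CLASSES for _ in range(batch_size)]
    for row, index in zip(rows, int_or_list):
        row[index] = 1.0
    return rows


# For ['cake'] no index is collected, but batch_size is still 1.
classes = []
try:
    one_hot_from_int_sketch(classes, batch_size=1)
except AssertionError:
    print(f"AssertionError: batch_size=1, len(classes)={len(classes)}")
```

With an empty `classes` list, neither the wrap nor the repeat applies, so the assertion compares `1 == 0` and fails, which matches the bare `AssertionError:` in the report.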
The following code reproduces the error:
```python
import torch
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names

model = BigGAN.from_pretrained('biggan-deep-256')
class_vector = one_hot_from_names(['cake'], batch_size=1)
```

This throws the following error:
```
---------------------------------------------------------------------------
AssertionError Traceback (most recent call last)
<ipython-input-4-4dd4cd1296e1> in <module>()
1
----> 2 class_vector = one_hot_from_names(['cake'], batch_size=1)

/home/saqib/Projects/Poem2GIF/repos/pytorch-pretrained-BigGAN/pytorch_pretrained_biggan/utils.py in one_hot_from_names(class_name_or_list, batch_size)
211 classes.append(IMAGENET[possible_synsets[0].offset()])
212
--> 213 return one_hot_from_int(classes, batch_size=batch_size)
214
215

/home/saqib/Projects/Poem2GIF/repos/pytorch-pretrained-BigGAN/pytorch_pretrained_biggan/utils.py in one_hot_from_int(int_or_list, batch_size)
164 int_or_list = [int_or_list[0]] * batch_size
165
--> 166 assert batch_size == len(int_or_list)
167
168 array = np.zeros((batch_size, NUM_CLASSES), dtype=np.float32)

AssertionError:
```
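One way to apply the change the report points to is sketched below: derive `batch_size` from the indices that were actually resolved, and raise a readable error when none were. This is a hypothetical patch, not the library's code. `one_hot_from_names_fixed`, `one_hot_rows` and `name_to_index` are illustrative names, and the dictionary lookup stands in for the WordNet/`IMAGENET` lookup that the real function performs.

```python
NUM_CLASSES = 1000  # number of ImageNet classes


def one_hot_rows(indices, batch_size):
    """Pure-Python stand-in for one_hot_from_int: lists of floats instead of a NumPy array."""
    if len(indices) == 1 and batch_size > 1:
        indices = indices * batch_size  # repeat a single class across the batch
    assert batch_size == len(indices)
    rows = [[0.0] * NUM_CLASSES for _ in range(batch_size)]
    for row, index in zip(rows, indices):
        row[index] = 1.0
    return rows


def one_hot_from_names_fixed(names, name_to_index, batch_size=1):
    """Hypothetical patched one_hot_from_names; name_to_index replaces the WordNet lookup."""
    if isinstance(names, str):
        names = [names]
    # Resolve names to indices; unknown names are dropped, as described in the report.
    classes = [name_to_index[name] for name in names if name in name_to_index]
    if not classes:
        # Fail with a descriptive message instead of a bare AssertionError.
        raise ValueError(f"no ImageNet class found for any of {names!r}")
    if len(classes) > 1:
        # Update batch_size after converting names to indices,
        # so skipped names can no longer cause a length mismatch.
        batch_size = len(classes)
    return one_hot_rows(classes, batch_size=batch_size)


# Toy lookup table; the indices are illustrative only.
toy_index = {"daisy": 985, "tennis ball": 852}

batch = one_hot_from_names_fixed(["daisy", "cake", "tennis ball"], toy_index)
print(len(batch))  # prints 2: 'cake' was skipped

try:
    one_hot_from_names_fixed(["cake"], toy_index)
except ValueError as err:
    print(err)
```

Keeping `batch_size` unchanged when exactly one index is resolved preserves the "repeat one class across the batch" behaviour. The trade-off is that a partly resolved list silently yields a smaller batch than requested; raising on any unresolved name would be the stricter alternative.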