
AssertionError when working with classes which are not in ImageNet #16

@saqibns


The function one_hot_from_names throws an AssertionError when it is given a class name that is not one of the original ImageNet classes and for which no matching synsets exist either.

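This happens because such a name contributes no class index when the words are converted to their respective indices, yet batch_size is not updated before one_hot_from_int is called in utils.py. The list of indices is therefore shorter than batch_size, and the check assert batch_size == len(int_or_list) fails.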
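A minimal, self-contained sketch of this mechanism, assuming a toy name-to-index table in place of the real WordNet/ImageNet lookup. TOY_IMAGENET, names_to_indices and one_hot_from_int_sketch are hypothetical stand-ins, and the indices are illustrative; only the length check mirrors the assert in one_hot_from_int shown in the traceback below. The guard that repeats a single index is an assumption, since the traceback shows only the body of that branch.

```python
# Hypothetical stand-in for the WordNet/ImageNet lookup (indices illustrative).
TOY_IMAGENET = {"daisy": 985, "tennis ball": 852}
NUM_CLASSES = 1000


def names_to_indices(names):
    """Resolve names to class indices, silently skipping unmatched names."""
    indices = []
    for name in names:
        if name in TOY_IMAGENET:
            indices.append(TOY_IMAGENET[name])
        # An unmatched name contributes nothing, so the list gets shorter.
    return indices


def one_hot_from_int_sketch(indices, batch_size=1):
    """Mimics the length handling of one_hot_from_int seen in the traceback."""
    # Assumed guard: a single index is repeated to fill the batch.
    if len(indices) == 1 and batch_size > 1:
        indices = [indices[0]] * batch_size
    # Same check as in utils.py: fails if a name was dropped upstream.
    assert batch_size == len(indices)
    return [[1.0 if c == i else 0.0 for c in range(NUM_CLASSES)] for i in indices]


indices = names_to_indices(["cake"])  # 'cake' has no entry, so indices == []
try:
    one_hot_from_int_sketch(indices, batch_size=1)
except AssertionError:
    print(f"AssertionError: batch_size=1 but len(indices)={len(indices)}")
```

With batch_size=1 and an empty index list, the check compares 1 with 0, which is the same failure reported below.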

The following lines reproduce the error:

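import torch
from pytorch_pretrained_biggan import BigGAN, one_hot_from_names

# Loading the model is not required to trigger the error
model = BigGAN.from_pretrained('biggan-deep-256')
# 'cake' is not an ImageNet class and no matching synset is found for it
class_vector = one_hot_from_names(['cake'], batch_size=1)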

This raises the following error:

---------------------------------------------------------------------------
AssertionError                            Traceback (most recent call last)
<ipython-input-4-4dd4cd1296e1> in <module>()
      1 
----> 2 class_vector = one_hot_from_names(['cake'], batch_size=1)

/home/saqib/Projects/Poem2GIF/repos/pytorch-pretrained-BigGAN/pytorch_pretrained_biggan/utils.py in one_hot_from_names(class_name_or_list, batch_size)
    211                 classes.append(IMAGENET[possible_synsets[0].offset()])
    212 
--> 213     return one_hot_from_int(classes, batch_size=batch_size)
    214 
    215 

/home/saqib/Projects/Poem2GIF/repos/pytorch-pretrained-BigGAN/pytorch_pretrained_biggan/utils.py in one_hot_from_int(int_or_list, batch_size)
    164         int_or_list = [int_or_list[0]] * batch_size
    165 
--> 166     assert batch_size == len(int_or_list)
    167 
    168     array = np.zeros((batch_size, NUM_CLASSES), dtype=np.float32)

AssertionError: 
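One possible fix, sketched under the same kind of assumption (a hypothetical TOY_IMAGENET table instead of the real WordNet lookup; resolve_names is illustrative and not part of the library): record the names that could not be resolved and raise a descriptive error, and otherwise size the batch from the resolved indices so it always agrees with the list handed to one_hot_from_int.

```python
# Hypothetical stand-in for the WordNet/ImageNet lookup (indices illustrative).
TOY_IMAGENET = {"daisy": 985, "tennis ball": 852}


def resolve_names(names, batch_size=1):
    """Return (indices, batch_size) with consistent lengths, or raise a
    ValueError naming every class that could not be resolved."""
    indices, missing = [], []
    for name in names:
        if name in TOY_IMAGENET:
            indices.append(TOY_IMAGENET[name])
        else:
            missing.append(name)
    if missing:
        raise ValueError(f"No ImageNet class found for: {', '.join(missing)}")
    # A single name may still be repeated across a larger batch; a list of
    # several names defines the batch size itself.
    if len(indices) > 1:
        batch_size = len(indices)
    return indices, batch_size


indices, batch_size = resolve_names(["daisy", "tennis ball"])
print(indices, batch_size)  # [985, 852] 2
```

Raising a ValueError instead of letting the assert fire tells the caller which name failed; whether the library should raise, skip the name, or fall back to another class is a design decision this sketch leaves open.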
