Open
Description
Hi guys, nice work! However, it looks like you forgot to move the torch.eye tensor to the device in one_hot_embedding.
I modified one_hot_embedding in helpers.py as follows:
def one_hot_embedding(labels, num_classes=10):
    # Convert to One Hot Encoding
    device = get_device()
    y = torch.eye(num_classes).to(device)
    return y[labels]
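For what it's worth, here is another possible variant, just a sketch and not the repo's code. It assumes labels is always a torch tensor, so the identity matrix can be created directly on labels.device and the result ends up on the same device as the labels without calling get_device():

```python
import torch


def one_hot_embedding(labels, num_classes=10):
    # Assumption: `labels` is an integer torch tensor of class indices.
    # Build the identity matrix on the device that already holds `labels`,
    # so the indexing below never mixes CPU and GPU tensors.
    y = torch.eye(num_classes, device=labels.device)
    # Row i of the identity matrix is the one-hot vector for class i.
    return y[labels]


if __name__ == "__main__":
    labels = torch.tensor([0, 2, 1])
    print(one_hot_embedding(labels, num_classes=3))
```

Passing device= to torch.eye also avoids allocating the matrix on the CPU first and then copying it over with .to(device).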