diff --git a/utils/dataloader.py b/utils/dataloader.py index d6386af..0d66e46 100644 --- a/utils/dataloader.py +++ b/utils/dataloader.py @@ -35,8 +35,8 @@ def __init__(self, db_path, ext): if self.ext == '.npy': self.loader = lambda x: np.load(x) else: - self.loader = lambda x: np.load(x)['x'] - # self.loader = lambda x: np.transpose(np.load(x)['x']) + # self.loader = lambda x: np.load(x)['x'] + self.loader = lambda x: np.transpose(np.load(x)['x']) if db_path.endswith('.pth'): # Assume a key,value dictionary self.db_type = 'pth' self.feat_file = torch.load(db_path) @@ -157,7 +157,7 @@ def __init__(self, opt): # self.fc_loader = HybridLoader(self.opt.input_fc_dir, '.npy') self.att_loader = HybridLoader(self.opt.image_feat_dir, '.npz') self.att2_loader = None - if len(self.opt.image_feat_dir2) != 0: + if hasattr(self.opt, 'image_feat_dir2') and len(self.opt.image_feat_dir2) != 0: self.att2_loader = HybridLoaderv2(self.opt.image_feat_dir, '.npz', self.opt.image_feat_dir2) # self.box_loader = HybridLoader(self.opt.input_att_dir, '.npz')[''] @@ -265,10 +265,10 @@ def collate_func(self, batch, split): max_att_len = max([i[0] for i in num_bbox_batch]) data['att_feats'] = np.zeros([len(att_batch)*seq_per_img, 100, att_batch[0].shape[1]], dtype = 'float32') for i in range(len(att_batch)): - data['att_feats'][i*5:(i+1)*5, :att_batch[i].shape[0]] = np.tile(att_batch[i], (5, 1)).reshape(5, att_batch[i].shape[0], att_batch[i].shape[1]) + data['att_feats'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = np.tile(att_batch[i], (seq_per_img, 1)).reshape(seq_per_img, att_batch[i].shape[0], att_batch[i].shape[1]) data['att_masks'] = np.zeros(data['att_feats'].shape[:2], dtype='float32') for i in range(len(att_batch)): - data['att_masks'][i*5:(i+1)*5, :att_batch[i].shape[0]] = np.ones([5, att_batch[i].shape[0]]) + data['att_masks'][i*seq_per_img:(i+1)*seq_per_img, :att_batch[i].shape[0]] = np.ones([seq_per_img, att_batch[i].shape[0]]) # set att_masks to None if attention features have same length if data['att_masks'].sum() == data['att_masks'].size: data['att_masks'] = None