Description
Hello,
I've found that the message can be negative at inference time if we set comm_narrow=True (together with hard=True).
However, during training, the message is always passed through sigmoid(), which returns only positive values.
The part I'm referring to is in dru.py. regularize() always returns a positive value, but discretize() can return a negative one. Specifically: return (m.gt(0.5).float() - 0.5).sign().float()
Is this part mistaken, or am I misunderstanding something?
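To illustrate with a toy tensor (the values here are arbitrary, just to show the ranges): the sigmoid used during training stays strictly inside (0, 1), while that line maps anything at or below 0.5 to -1.

import torch

m = torch.tensor([0.2, 0.7])
print(torch.sigmoid(m))                          # roughly tensor([0.5498, 0.6682]) -- always in (0, 1)
print((m.gt(0.5).float() - 0.5).sign().float())  # tensor([-1., 1.]) -- can be negative

The full class from dru.py is below for context.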
import torch


class DRU:
    def __init__(self, sigma, comm_narrow=True, hard=False):
        self.sigma = sigma
        self.comm_narrow = comm_narrow
        self.hard = hard

    def regularize(self, m):
        # Training path: add channel noise, then squash with sigmoid / softmax
        m_reg = m + torch.randn(m.size()) * self.sigma
        if self.comm_narrow:
            m_reg = torch.sigmoid(m_reg)
        else:
            m_reg = torch.softmax(m_reg, 0)
        return m_reg

    def discretize(self, m):
        if self.hard:
            if self.comm_narrow:
                # Thresholds at 0.5 and returns -1 or +1, not 0 or 1
                return (m.gt(0.5).float() - 0.5).sign().float()
            else:
                # One-hot encoding of the argmax
                m_ = torch.zeros_like(m)
                if m.dim() == 1:
                    _, idx = m.max(0)
                    m_[idx] = 1.
                elif m.dim() == 2:
                    _, idx = m.max(1)
                    for b in range(idx.size(0)):
                        m_[b, idx[b]] = 1.
                else:
                    raise ValueError('Wrong message shape: {}'.format(m.size()))
                return m_
        else:
            # Soft discretization via a steep sigmoid / softmax
            scale = 2 * 20
            if self.comm_narrow:
                return torch.sigmoid((m.gt(0.5).float() - 0.5) * scale)
            else:
                return torch.softmax(m * scale, -1)

    def forward(self, m, train_mode):
        if train_mode:
            return self.regularize(m)
        else:
            return self.discretize(m)
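For reference, here is a minimal sketch of how I hit this, assuming the class above is importable from dru.py and with an arbitrary sigma:

import torch

from dru import DRU  # assuming the class above lives in dru.py

dru = DRU(sigma=2.0, comm_narrow=True, hard=True)
m = torch.tensor([-1.3, 0.4, 2.1])

train_msg = dru.forward(m, train_mode=True)   # noisy sigmoid; values vary with the noise but stay in (0, 1)
eval_msg = dru.forward(m, train_mode=False)   # hard threshold at 0.5; values are -1 or +1

print(train_msg)  # always positive
print(eval_msg)   # tensor([-1., -1.,  1.]) -- negative entries at inference time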