```python
def compute_nll(
    self,
    mean: torch.Tensor,
    logvar: torch.Tensor,
    labels: torch.Tensor,
) -> torch.Tensor:
    """Compute the loss as negative log likelihood.

    Args:
        mean: The mean prediction for labels.
        logvar: The logvariance prediction for labels.
        labels: The observed labels of the data.

    Returns:
        The negative log likelihood of each point.
    """
    sqdiffs = (mean - labels) ** 2
    return torch.exp(-logvar) * sqdiffs + logvar
```
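For reference, here is the per-point Gaussian negative log-likelihood written with $y$ = `labels`, $\mu$ = `mean` and $s$ = `logvar` (so the variance is $e^{s}$). This is a standard derivation, not taken from the repository. It shows that the returned expression equals twice the NLL minus the constant $\log(2\pi)$, which does not affect gradients:

$$
-\log \mathcal{N}(y \mid \mu, e^{s}) = \tfrac{1}{2}\log(2\pi) + \tfrac{1}{2}s + \tfrac{1}{2}e^{-s}(y-\mu)^2
$$

$$
e^{-s}(y-\mu)^2 + s = \frac{(y-\mu)^2}{e^{s}} + s = 2\left[-\log \mathcal{N}(y \mid \mu, e^{s})\right] - \log(2\pi)
$$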
Isn't this NLL loss wrong here?
It should be `return sqdiffs / torch.exp(logvar) + logvar`.
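One way to check the proposed change is to evaluate both expressions on the same tensors. Because $e^{-v} = 1/e^{v}$, they should agree up to floating-point rounding. The sketch below uses random stand-in inputs (the tensors are hypothetical, not data from the project). It also compares against PyTorch's `torch.nn.functional.gaussian_nll_loss`, which with the default `full=False` and `reduction="none"` computes `0.5 * (log(var) + (input - target)**2 / var)` per element. The `compute_nll` output should therefore be twice that value.

```python
import torch

# Hypothetical random inputs standing in for a batch of predictions and labels.
torch.manual_seed(0)
mean = torch.randn(8)
logvar = torch.randn(8)
labels = torch.randn(8)
sqdiffs = (mean - labels) ** 2

# Expression currently returned by compute_nll.
current = torch.exp(-logvar) * sqdiffs + logvar
# Expression proposed in this issue.
proposed = sqdiffs / torch.exp(logvar) + logvar

# exp(-v) == 1 / exp(v), so both forms should match up to rounding.
print(torch.allclose(current, proposed))  # True

# PyTorch's built-in Gaussian NLL (no 0.5*log(2*pi) constant when full=False).
reference = torch.nn.functional.gaussian_nll_loss(
    mean, labels, torch.exp(logvar), reduction="none"
)
# compute_nll drops the factor 0.5, so it should equal twice the reference.
print(torch.allclose(current, 2 * reference))
```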