Skip to content

Commit dc3cfc3

Browse files
committed
fixed forwarding of temperature scaling in classifier with embedding
1 parent 4ae94d7 commit dc3cfc3

File tree

1 file changed

+18
-12
lines changed

1 file changed

+18
-12
lines changed

src/cryo_sbi/inference/models/estimator_models.py

Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -48,23 +48,30 @@ def __init__(
4848
self.classifier.append(nn.Linear(nodes_per_layer, num_classes))
4949
self.classifier = nn.Sequential(*self.classifier)
5050

51-
def forward(self, z):
52-
return self.classifier(z)
51+
def forward(self, z, tau=1.0):
52+
return self.classifier(z) / tau
5353

5454

5555
@add_classifier("PROTOTYPE")
5656
class PrototypeClassifier(BaseClassifier):
57-
def __init__(self, input_dim, num_classes):
57+
def __init__(self, input_dim, num_classes, noise_scale=0.0):
5858
super().__init__(input_dim, num_classes)
5959

6060
self.input_dim = input_dim
6161
self.num_classes = num_classes
62+
self.noise_scale = noise_scale
6263
self.prototypes = nn.Parameter(torch.randn(self.num_classes, self.input_dim))
6364

6465
def forward(self, z, tau=1.0):
66+
if self.training and self.noise_scale > 0.0:
67+
prototypes = self.prototypes + torch.randn_like(self.prototypes) * self.noise_scale
68+
else:
69+
prototypes = self.prototypes
70+
6571
z2 = (z**2).sum(dim=1, keepdim=True)
66-
p2 = (self.prototypes**2).sum(dim=1).unsqueeze(0)
67-
logits = -(z2 + p2 - 2 * z @ self.prototypes.T) / tau
72+
p2 = (prototypes**2).sum(dim=1).unsqueeze(0)
73+
logits = -(z2 + p2 - 2 * z @ prototypes.T) / tau
74+
6875
return logits
6976

7077

@@ -78,13 +85,12 @@ def __init__(
7885
self.classifier = classifier()
7986
self.embedding = embedding_net()
8087

81-
def forward(self, x: torch.Tensor) -> torch.Tensor:
82-
return self.classifier(self.embedding(x))
83-
84-
def probs(self, x: torch.Tensor) -> torch.Tensor:
85-
return torch.nn.functional.softmax(self.forward(x), dim=1)
88+
def forward(self, x: torch.Tensor, tau=1.0) -> torch.Tensor:
89+
return self.classifier(self.embedding(x), tau=tau)
8690

87-
def logits_embedding(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
91+
def probs(self, x: torch.Tensor, tau=1.0) -> torch.Tensor:
92+
return torch.nn.functional.softmax(self.forward(x, tau=tau) / tau, dim=1)
93+
def logits_embedding(self, x: torch.Tensor, tau=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
8894
embeddings = self.embedding(x)
89-
logits = self.classifier(embeddings)
95+
logits = self.classifier(embeddings, tau=tau)
9096
return logits, embeddings

0 commit comments

Comments
 (0)