@@ -48,23 +48,30 @@ def __init__(
         self.classifier.append(nn.Linear(nodes_per_layer, num_classes))
         self.classifier = nn.Sequential(*self.classifier)
 
-    def forward(self, z):
-        return self.classifier(z)
+    def forward(self, z, tau=1.0):
+        return self.classifier(z) / tau
 
 
 @add_classifier("PROTOTYPE")
 class PrototypeClassifier(BaseClassifier):
-    def __init__(self, input_dim, num_classes):
+    def __init__(self, input_dim, num_classes, noise_scale=0.0):
         super().__init__(input_dim, num_classes)
 
         self.input_dim = input_dim
         self.num_classes = num_classes
+        self.noise_scale = noise_scale
         self.prototypes = nn.Parameter(torch.randn(self.num_classes, self.input_dim))
 
     def forward(self, z, tau=1.0):
+        if self.training and self.noise_scale > 0.0:
+            prototypes = self.prototypes + torch.randn_like(self.prototypes) * self.noise_scale
+        else:
+            prototypes = self.prototypes
+
         z2 = (z ** 2).sum(dim=1, keepdim=True)
-        p2 = (self.prototypes ** 2).sum(dim=1).unsqueeze(0)
-        logits = -(z2 + p2 - 2 * z @ self.prototypes.T) / tau
+        p2 = (prototypes ** 2).sum(dim=1).unsqueeze(0)
+        logits = -(z2 + p2 - 2 * z @ prototypes.T) / tau
+
         return logits
 
 
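The prototype logits above are negative squared Euclidean distances between embeddings and class prototypes, computed via the expansion ||z - p||^2 = ||z||^2 + ||p||^2 - 2 z.p rather than materializing all pairwise differences. A minimal standalone sketch (not part of this commit; shapes are illustrative) checking the vectorized form against torch.cdist:

import torch

z = torch.randn(4, 16)            # batch of embeddings
prototypes = torch.randn(10, 16)  # one prototype per class
tau = 0.5

# Expansion used in the diff: ||z - p||^2 = ||z||^2 + ||p||^2 - 2 z.p
z2 = (z ** 2).sum(dim=1, keepdim=True)          # shape (4, 1)
p2 = (prototypes ** 2).sum(dim=1).unsqueeze(0)  # shape (1, 10)
logits = -(z2 + p2 - 2 * z @ prototypes.T) / tau

# Reference computation; the two agree up to floating-point error.
reference = -torch.cdist(z, prototypes) ** 2 / tau
assert torch.allclose(logits, reference, atol=1e-4)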
@@ -78,13 +85,13 @@ def __init__(
         self.classifier = classifier()
         self.embedding = embedding_net()
 
-    def forward(self, x: torch.Tensor) -> torch.Tensor:
-        return self.classifier(self.embedding(x))
-
-    def probs(self, x: torch.Tensor) -> torch.Tensor:
-        return torch.nn.functional.softmax(self.forward(x), dim=1)
+    def forward(self, x: torch.Tensor, tau=1.0) -> torch.Tensor:
+        return self.classifier(self.embedding(x), tau=tau)
 
-    def logits_embedding(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def probs(self, x: torch.Tensor, tau=1.0) -> torch.Tensor:
+        return torch.nn.functional.softmax(self.forward(x, tau=tau), dim=1)
+
+    def logits_embedding(self, x: torch.Tensor, tau=1.0) -> Tuple[torch.Tensor, torch.Tensor]:
         embeddings = self.embedding(x)
-        logits = self.classifier(embeddings)
+        logits = self.classifier(embeddings, tau=tau)
         return logits, embeddings
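A rough usage sketch of the new tau and noise_scale behavior (the instantiation is hypothetical, since BaseClassifier's definition lies outside this diff): small tau sharpens the softmax over prototype distances, large tau flattens it, and train() mode draws a fresh noise sample for the prototypes on every forward pass.

# Hypothetical instantiation; BaseClassifier is not shown in this diff.
clf = PrototypeClassifier(input_dim=16, num_classes=10, noise_scale=0.1)
z = torch.randn(4, 16)

clf.eval()  # deterministic prototypes
sharp = torch.softmax(clf(z, tau=0.1), dim=1)   # close to one-hot
flat = torch.softmax(clf(z, tau=10.0), dim=1)   # close to uniform

clf.train()  # noise_scale > 0, so prototypes are perturbed per call
assert not torch.equal(clf(z), clf(z))  # two calls, two noise samples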