1+ from typing import Tuple
12import torch
23import torch .nn as nn
34import zuko
45from lampe .inference import NPE , NRE
56
67
7- class Standardize (nn .Module ):
8- """
9- Module to standardize inputs and retransform them to the original space
10-
11- Args:
12- mean (torch.Tensor): mean of the data
13- std (torch.Tensor): standard deviation of the data
14-
15- Returns:
16- standardized (torch.Tensor): standardized data
17- """
18-
19- # Code adapted from :https://github.com/mackelab/sbi/blob/main/sbi/utils/sbiutils.py
20- def __init__ (self , mean : float , std : float ) -> None :
21- super (Standardize , self ).__init__ ()
22- mean , std = map (torch .as_tensor , (mean , std ))
23- self .mean = mean
24- self .std = std
25- self .register_buffer ("_mean" , mean )
26- self .register_buffer ("_std" , std )
27-
28- def forward (self , tensor : torch .Tensor ) -> torch .Tensor :
29- """
30- Standardize the input tensor
31-
32- Args:
33- tensor (torch.Tensor): input tensor
34-
35- Returns:
36- standardized (torch.Tensor): standardized tensor
37- """
38-
39- return (tensor - self ._mean ) / self ._std
40-
41- def transform (self , tensor : torch .Tensor ) -> torch .Tensor :
42- """
43- Transform the standardized tensor back to the original space
44-
45- Args:
46- tensor (torch.Tensor): input tensor
8+ class Classifier (nn .Module ):
9+ def __init__ (self , input_dim , out_dim , num_layers , nodes_per_layer , activation = nn .ReLU , dropout = 0.0 ):
10+ super ().__init__ ()
11+ self .classifier = nn .ModuleList ()
4712
48- Returns:
49- retransformed (torch.Tensor): retransformed tensor
50- """
13+ for i in range (num_layers ):
14+ if i == 0 :
15+ self .classifier .append (nn .Linear (input_dim , nodes_per_layer ))
16+ else :
17+ self .classifier .append (nn .Linear (nodes_per_layer , nodes_per_layer ))
18+ if dropout > 0.0 :
19+ self .classifier .append (nn .Dropout (dropout ))
20+ self .classifier .append (activation ())
21+
22+ self .classifier .append (nn .Linear (nodes_per_layer , out_dim ))
23+ self .classifier = nn .Sequential (* self .classifier )
5124
52- return (tensor * self ._std ) + self ._mean
25+ def forward (self , x ):
26+ return self .classifier (x )
5327
5428
55- class NPEWithEmbedding (nn .Module ):
56- """Neural Posterior Estimation with embedding net
29+ class ClassifierWithEmbedding (nn .Module ):
30+ """Classification with embedding net
5731
5832 Attributes:
59- npe (NPE): NPE model
33+ classifier: Classification model
6034 embedding (nn.Module): embedding net
61- standardize (Standardize): standardization module
6235 """
6336
6437 def __init__ (
6538 self ,
6639 embedding_net : nn .Module ,
6740 output_embedding_dim : int ,
68- num_transforms : int = 4 ,
69- num_hidden_flow : int = 2 ,
70- hidden_flow_dim : int = 128 ,
71- flow : nn .Module = zuko .flows .MAF ,
72- theta_shift : float = 0.0 ,
73- theta_scale : float = 1.0 ,
41+ num_classes : int = 2 ,
42+ num_layers : int = 5 ,
43+ nodes_per_layer : int = 128 ,
7444 ** kwargs ,
7545 ) -> None :
7646 """
@@ -93,55 +63,50 @@ def __init__(
9363
9464 super ().__init__ ()
9565
96- self .npe = NPE (
97- 1 ,
66+ self .classifier = Classifier (
9867 output_embedding_dim ,
99- transforms = num_transforms ,
100- build = flow ,
101- hidden_features = [ * [ hidden_flow_dim ] * num_hidden_flow , 128 , 64 ] ,
68+ num_classes ,
69+ num_layers ,
70+ nodes_per_layer ,
10271 ** kwargs ,
10372 )
104-
10573 self .embedding = embedding_net ()
106- self .standardize = Standardize (theta_shift , theta_scale )
10774
108- def forward (self , theta : torch . Tensor , x : torch .Tensor ) -> torch .Tensor :
75+ def forward (self , x : torch .Tensor ) -> torch .Tensor :
10976 """
110- Forward pass of the NPE model
111-
77+ Forward pass of the classifier model
11278 Args:
11379 theta (torch.Tensor): Conformational parameters.
11480 x (torch.Tensor): Image to condition the posterior on.
115-
11681 Returns:
117- torch.Tensor: Log probability of the posterior .
82+ torch.Tensor: unnormalized class probabilities .
11883 """
11984
120- return self .npe ( self . standardize ( theta ), self .embedding (x ))
121-
122- def flow (self , x : torch .Tensor ):
85+ return self .classifier ( self .embedding (x ))
86+
87+ def prob (self , x : torch .Tensor ) -> torch . Tensor :
12388 """
124- Conditions the posterior on an image .
125-
89+ Predict the class probabilities for the input data .
90+
12691 Args:
127- x (torch.Tensor): Image to condition the posterior on .
128-
92+ x (torch.Tensor): Input data .
93+
12994 Returns:
130- zuko.flows.Flow: The posterior distribution .
95+ torch.Tensor: Class probabilities .
13196 """
132- return self . npe . flow (self .embedding (x ))
133-
134- def sample (self , x : torch .Tensor , shape = ( 1 ,)) -> torch .Tensor :
97+ return torch . nn . functional . softmax (self .forward (x ))
98+
99+ def logits_embedding (self , x : torch .Tensor ) -> Tuple [ torch .Tensor , torch . Tensor ] :
135100 """
136- Generate samples from the posterior distribution .
101+ Get the logits from the classifier and the embedding from the embedding net .
137102
138103 Args:
139- x (torch.Tensor): Image to condition the posterior on.
140- shape (tuple, optional): Shape of the samples. Defaults to (1,).
104+ x (torch.Tensor): Input data.
141105
142106 Returns:
143- torch.Tensor: Samples from the posterior distribution.
107+ torch.Tensor: Logits from the classifier.
108+ torch.Tensor: Embedding from the embedding net.
144109 """
145-
146- samples_standardized = self .flow ( x ). sample ( shape )
147- return self . standardize . transform ( samples_standardized )
110+ embeddings = self . embedding ( x )
111+ logits = self .classifier ( embeddings )
112+ return logits , embeddings
0 commit comments