Commit 7eb1447
classifier
1 parent d145cf2 commit 7eb1447

10 files changed: +190 -154 lines changed

pyproject.toml

Lines changed: 1 addition & 0 deletions
@@ -28,4 +28,5 @@ dependencies = [
 
 [project.scripts]
 train_npe_model = "cryo_sbi.inference.command_line_tools:cl_npe_train_no_saving"
+train_classifier = "cryo_sbi.inference.command_line_tools:cl_classifier_train_no_saving"
 model_to_tensor = "cryo_sbi.utils.command_line_tools:cl_models_to_tensor"
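The new `[project.scripts]` entry exposes classifier training as a `train_classifier` console command. A minimal sketch of what that command resolves to, assuming the cryo_sbi package from this commit is installed:

```python
# Equivalent of the `train_classifier` console script added above (a sketch,
# assuming the package from this commit is installed). The entry point imports
# and calls the CLI wrapper, which parses command-line flags and dispatches to
# cryo_sbi.inference.train_models.train_classifier.
from cryo_sbi.inference.command_line_tools import cl_classifier_train_no_saving

if __name__ == "__main__":
    cl_classifier_train_no_saving()
```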

src/cryo_sbi/inference/command_line_tools.py

Lines changed: 5 additions & 5 deletions
@@ -1,10 +1,10 @@
 import argparse
-from cryo_sbi.inference.train_npe_model import (
-    npe_train_no_saving,
+from cryo_sbi.inference.train_models import (
+    train_classifier
 )
 
 
-def cl_npe_train_no_saving():
+def cl_classifier_train_no_saving():
     cl_parser = argparse.ArgumentParser()
 
     cl_parser.add_argument(
@@ -47,7 +47,7 @@ def cl_npe_train_no_saving():
 
     args = cl_parser.parse_args()
 
-    npe_train_no_saving(
+    train_classifier(
         image_config=args.image_config_file,
         train_config=args.train_config_file,
         epochs=args.epochs,
@@ -59,4 +59,4 @@ def cl_npe_train_no_saving():
         device=args.train_device,
         saving_frequency=args.saving_freq,
         simulation_batch_size=args.simulation_batch_size,
-    )
+    )
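For reference, a sketch of calling the training function directly rather than through the console script. Only the keyword arguments visible in this hunk are shown; the hunk elides several others, and the paths and values below are hypothetical placeholders:

```python
# Illustrative only: the hunk above hides some of train_classifier's keyword
# arguments (file lines 54-58), so this call is not complete; the config paths
# and values are hypothetical placeholders.
from cryo_sbi.inference.train_models import train_classifier

train_classifier(
    image_config="image_params.json",      # hypothetical config path
    train_config="classifier_train.json",  # hypothetical config path
    epochs=150,
    device="cuda",
    saving_frequency=20,
    simulation_batch_size=1024,
)
```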

src/cryo_sbi/inference/models/build_models.py

Lines changed: 9 additions & 24 deletions
@@ -6,7 +6,7 @@
 from cryo_sbi.inference.models.embedding_nets import EMBEDDING_NETS
 
 
-def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
+def build_classifier(config: dict, **embedding_kwargs) -> nn.Module:
     """
     Function to build NPE estimator with embedding net
     from config_file
@@ -19,17 +19,6 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
         estimator (nn.Module): NPE estimator
     """
 
-    if config["MODEL"] == "MAF":
-        model = zuko.flows.MAF
-    elif config["MODEL"] == "NSF":
-        model = zuko.flows.NSF
-    elif config["MODEL"] == "SOSPF":
-        model = zuko.flows.SOSPF
-    else:
-        raise NotImplementedError(
-            f"Model : {config['MODEL']} has not been implemented yet!"
-        )
-
     try:
         embedding = partial(
             EMBEDDING_NETS[config["EMBEDDING"]], config["OUT_DIM"], **embedding_kwargs
@@ -40,20 +29,16 @@ def build_npe_flow_model(config: dict, **embedding_kwargs) -> nn.Module:
            The following embeddings are implemented : {[key for key in EMBEDDING_NETS.keys()]}"
         )
 
-    estimator = estimator_models.NPEWithEmbedding(
+    estimator = estimator_models.ClassifierWithEmbedding(
         embedding_net=embedding,
         output_embedding_dim=config["OUT_DIM"],
-        num_transforms=config["NUM_TRANSFORM"],
-        num_hidden_flow=config["NUM_HIDDEN_FLOW"],
-        hidden_flow_dim=config["HIDDEN_DIM_FLOW"],
-        flow=model,
-        theta_shift=config["THETA_SHIFT"],
-        theta_scale=config["THETA_SCALE"],
-        **{"activation": partial(nn.LeakyReLU, 0.1)},
+        num_classes=config["NUM_CLASSES"],
+        num_layers=config["NUM_LAYERS"],
+        nodes_per_layer=config["NODES_PER_LAYER"],
+        **{
+            "activation": partial(nn.LeakyReLU, 0.1),
+            "dropout": config["DROPOUT"],
+        },
     )
 
     return estimator
-
-
-def build_nre_classifier_model(config: dict, **embedding_kwargs) -> nn.Module:
-    raise NotImplementedError("NRE classifier model has not been implemented yet!")
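With this change the config no longer needs the flow settings (`MODEL`, `NUM_TRANSFORM`, `NUM_HIDDEN_FLOW`, `HIDDEN_DIM_FLOW`, `THETA_SHIFT`, `THETA_SCALE`); instead `build_classifier` reads classifier hyperparameters. A hypothetical config covering exactly the keys consumed in this hunk (the embedding name and all values are illustrative placeholders, not taken from the repository):

```python
# Hypothetical config for build_classifier, listing only the keys read in this
# diff; values and the embedding name are placeholders.
classifier_config = {
    "EMBEDDING": "RESNET18",   # must be a key of EMBEDDING_NETS (name is a placeholder)
    "OUT_DIM": 256,            # embedding output dim, i.e. the classifier head's input size
    "NUM_CLASSES": 2,          # width of the final linear layer
    "NUM_LAYERS": 5,           # hidden Linear (+ optional Dropout) + LeakyReLU blocks
    "NODES_PER_LAYER": 128,    # width of each hidden layer
    "DROPOUT": 0.1,            # forwarded to Classifier through **kwargs
}

# from cryo_sbi.inference.models.build_models import build_classifier
# estimator = build_classifier(classifier_config)
```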
Lines changed: 49 additions & 84 deletions
@@ -1,76 +1,46 @@
+from typing import Tuple
 import torch
 import torch.nn as nn
 import zuko
 from lampe.inference import NPE, NRE
 
 
-class Standardize(nn.Module):
-    """
-    Module to standardize inputs and retransform them to the original space
-
-    Args:
-        mean (torch.Tensor): mean of the data
-        std (torch.Tensor): standard deviation of the data
-
-    Returns:
-        standardized (torch.Tensor): standardized data
-    """
-
-    # Code adapted from :https://github.com/mackelab/sbi/blob/main/sbi/utils/sbiutils.py
-    def __init__(self, mean: float, std: float) -> None:
-        super(Standardize, self).__init__()
-        mean, std = map(torch.as_tensor, (mean, std))
-        self.mean = mean
-        self.std = std
-        self.register_buffer("_mean", mean)
-        self.register_buffer("_std", std)
-
-    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
-        """
-        Standardize the input tensor
-
-        Args:
-            tensor (torch.Tensor): input tensor
-
-        Returns:
-            standardized (torch.Tensor): standardized tensor
-        """
-
-        return (tensor - self._mean) / self._std
-
-    def transform(self, tensor: torch.Tensor) -> torch.Tensor:
-        """
-        Transform the standardized tensor back to the original space
-
-        Args:
-            tensor (torch.Tensor): input tensor
+class Classifier(nn.Module):
+    def __init__(self, input_dim, out_dim, num_layers, nodes_per_layer, activation=nn.ReLU, dropout=0.0):
+        super().__init__()
+        self.classifier = nn.ModuleList()
 
-        Returns:
-            retransformed (torch.Tensor): retransformed tensor
-        """
+        for i in range(num_layers):
+            if i == 0:
+                self.classifier.append(nn.Linear(input_dim, nodes_per_layer))
+            else:
+                self.classifier.append(nn.Linear(nodes_per_layer, nodes_per_layer))
+            if dropout > 0.0:
+                self.classifier.append(nn.Dropout(dropout))
+            self.classifier.append(activation())
+
+        self.classifier.append(nn.Linear(nodes_per_layer, out_dim))
+        self.classifier = nn.Sequential(*self.classifier)
 
-        return (tensor * self._std) + self._mean
+    def forward(self, x):
+        return self.classifier(x)
 
 
-class NPEWithEmbedding(nn.Module):
-    """Neural Posterior Estimation with embedding net
+class ClassifierWithEmbedding(nn.Module):
+    """Classification with embedding net
 
     Attributes:
-        npe (NPE): NPE model
+        classifier: Classification model
         embedding (nn.Module): embedding net
-        standardize (Standardize): standardization module
     """
 
     def __init__(
         self,
         embedding_net: nn.Module,
         output_embedding_dim: int,
-        num_transforms: int = 4,
-        num_hidden_flow: int = 2,
-        hidden_flow_dim: int = 128,
-        flow: nn.Module = zuko.flows.MAF,
-        theta_shift: float = 0.0,
-        theta_scale: float = 1.0,
+        num_classes: int = 2,
+        num_layers: int = 5,
+        nodes_per_layer: int = 128,
         **kwargs,
     ) -> None:
         """
@@ -93,55 +63,50 @@ def __init__(
 
         super().__init__()
 
-        self.npe = NPE(
-            1,
+        self.classifier = Classifier(
             output_embedding_dim,
-            transforms=num_transforms,
-            build=flow,
-            hidden_features=[*[hidden_flow_dim] * num_hidden_flow, 128, 64],
+            num_classes,
+            num_layers,
+            nodes_per_layer,
             **kwargs,
         )
-
         self.embedding = embedding_net()
-        self.standardize = Standardize(theta_shift, theta_scale)
 
-    def forward(self, theta: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Forward pass of the NPE model
-
+        Forward pass of the classifier model
         Args:
             theta (torch.Tensor): Conformational parameters.
             x (torch.Tensor): Image to condition the posterior on.
-
         Returns:
-            torch.Tensor: Log probability of the posterior.
+            torch.Tensor: unnormalized class probabilities.
         """
 
-        return self.npe(self.standardize(theta), self.embedding(x))
-
-    def flow(self, x: torch.Tensor):
+        return self.classifier(self.embedding(x))
+
+    def prob(self, x: torch.Tensor) -> torch.Tensor:
         """
-        Conditions the posterior on an image.
-
+        Predict the class probabilities for the input data.
+
         Args:
-            x (torch.Tensor): Image to condition the posterior on.
-
+            x (torch.Tensor): Input data.
+
         Returns:
-            zuko.flows.Flow: The posterior distribution.
+            torch.Tensor: Class probabilities.
         """
-        return self.npe.flow(self.embedding(x))
-
-    def sample(self, x: torch.Tensor, shape=(1,)) -> torch.Tensor:
+        return torch.nn.functional.softmax(self.forward(x))
+
+    def logits_embedding(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         """
-        Generate samples from the posterior distribution.
+        Get the logits from the classifier and the embedding from the embedding net.
 
         Args:
-            x (torch.Tensor): Image to condition the posterior on.
-            shape (tuple, optional): Shape of the samples. Defaults to (1,).
+            x (torch.Tensor): Input data.
 
         Returns:
-            torch.Tensor: Samples from the posterior distribution.
+            torch.Tensor: Logits from the classifier.
+            torch.Tensor: Embedding from the embedding net.
         """
-
-        samples_standardized = self.flow(x).sample(shape)
-        return self.standardize.transform(samples_standardized)
+        embeddings = self.embedding(x)
+        logits = self.classifier(embeddings)
+        return logits, embeddings
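To make the new classifier head easier to read outside the diff, here is a standalone sketch of the layer layout `Classifier` assembles: one Linear block per hidden layer (input_dim to nodes_per_layer on the first block, then nodes_per_layer to nodes_per_layer), optional Dropout, an activation, and a final Linear to the class logits. The dimensions and hyperparameters below are illustrative, not values from a repository config:

```python
# Standalone sketch mirroring the MLP head built by the new Classifier module;
# shapes and hyperparameters are illustrative.
import torch
import torch.nn as nn


def make_classifier_head(input_dim, out_dim, num_layers, nodes_per_layer,
                         activation=nn.ReLU, dropout=0.0):
    layers = []
    for i in range(num_layers):
        # first block maps the embedding to the hidden width, later blocks keep it
        layers.append(nn.Linear(input_dim if i == 0 else nodes_per_layer, nodes_per_layer))
        if dropout > 0.0:
            layers.append(nn.Dropout(dropout))
        layers.append(activation())
    layers.append(nn.Linear(nodes_per_layer, out_dim))  # final layer outputs logits
    return nn.Sequential(*layers)


head = make_classifier_head(input_dim=256, out_dim=2, num_layers=5,
                            nodes_per_layer=128, dropout=0.1)
x = torch.randn(8, 256)                 # e.g. a batch of 8 image embeddings
logits = head(x)                        # shape (8, 2)
probs = torch.softmax(logits, dim=-1)   # class probabilities, rows sum to 1
```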

src/cryo_sbi/inference/priors.py

Lines changed: 26 additions & 6 deletions
@@ -23,6 +23,31 @@ def gen_quat() -> torch.Tensor:
     return quat
 
 
+class IndexPrior():
+    def __init__(self, max_index: int, device="cpu") -> None:
+        self.max_index = max_index
+        self.device = device
+
+        self.index_prior = torch.distributions.Categorical(
+            probs=torch.tensor(
+                [1 / (max_index + 1) for _ in range(max_index + 1)],
+                device=device,
+            )
+        )
+
+    def sample(self, shape) -> torch.Tensor:
+        """
+        Sample indices from the prior distribution.
+
+        Args:
+            shape (tuple): Shape of the samples to be generated.
+
+        Returns:
+            torch.Tensor: Sampled indices.
+        """
+        return self.index_prior.sample(shape)
+
+
 def get_image_priors(
     max_index, image_config: dict, device="cuda"
 ) -> zuko.distributions.BoxUniform:
@@ -110,10 +135,7 @@ def get_image_priors(
         ndims=1,
     )
 
-    index_prior = zuko.distributions.BoxUniform(
-        lower=torch.tensor([0], dtype=torch.float32, device=device),
-        upper=torch.tensor([max_index], dtype=torch.float32, device=device),
-    )
+    index_prior = IndexPrior(max_index, device)
     quaternion_prior = QuaternionPrior(device)
     if (
         image_config.get("ROTATIONS")
@@ -132,7 +154,6 @@ def get_image_priors(
         b_factor_prior,
         amp_prior,
         snr_prior,
-        device=device,
    )
 
 
@@ -168,7 +189,6 @@ def __init__(
         b_factor_prior,
         amp_prior,
         snr_prior,
-        device,
     ) -> None:
         self.priors = [
             index_prior,
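The structure-index prior changes from a continuous `BoxUniform` over `[0, max_index]` to a uniform categorical distribution, so sampled indices are now integers rather than floats. A standalone sketch of the same construction (`max_index` and the sample shape are illustrative):

```python
# Standalone sketch of the uniform categorical prior that IndexPrior wraps:
# every structure index in 0..max_index gets probability 1 / (max_index + 1).
import torch

max_index = 9
index_prior = torch.distributions.Categorical(
    probs=torch.full((max_index + 1,), 1.0 / (max_index + 1))
)

indices = index_prior.sample((4,))  # e.g. tensor([3, 0, 7, 9]); integer dtype
```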

0 commit comments
