
Commit a6b7ead

Merge branch 'adversarial'
Merge adversarial branch into main
2 parents 0481d3c + fc1e296 commit a6b7ead

File tree: 18 files changed (+1468, -60 lines)


.gitignore

Lines changed: 2 additions & 0 deletions
@@ -170,3 +170,5 @@ cython_debug/
 experiments/normal_dim_2/notebooks/self-consistency-abi/
 checkpoints/
 experiments/normal_dim_2/stan/normal_dim_2
+data/
+plots/

experiments/hodgkin_huxley/__init__.py

Whitespace-only changes.
Lines changed: 123 additions & 0 deletions
@@ -0,0 +1,123 @@
import tensorflow as tf
from bayesflow.amortizers import AmortizedPosterior

from src.self_consistency_real.schedules import (
    ConstantSchedule,
)


class AmortizedPosteriorSC(AmortizedPosterior):
    def __init__(
        self,
        prior,
        simulator,
        real_data,
        lambda_schedule=ConstantSchedule(1.0),
        n_consistency_samples=32,
        theta_clip_value_min=-float("inf"),
        theta_clip_value_max=float("inf"),
        *args,
        **kwargs,
    ):
        super().__init__(*args, **kwargs)
        self.prior = prior
        self.simulator = simulator
        self.real_data = real_data  # tf.convert_to_tensor(real_data, dtype=tf.float32)
        self.step = tf.Variable(0, trainable=False, dtype=tf.int32)
        self.lambda_schedule = lambda_schedule
        self.n_consistency_samples = n_consistency_samples
        self.theta_clip_value_min = theta_clip_value_min
        self.theta_clip_value_max = theta_clip_value_max

    def compute_loss(self, input_dict, **kwargs):
        self.step.assign_add(1)
        lambda_ = self.lambda_schedule(self.step)

        # Get amortizer outputs
        net_out, sum_out = self(input_dict, return_summary=True, **kwargs)
        z, log_det_J = net_out

        # Case summary loss should be computed
        if self.summary_loss is not None:
            sum_loss = self.summary_loss(sum_out)
        # Case no summary loss, simply add 0 for convenience
        else:
            sum_loss = 0.0

        # Case dynamic latent space - function of summary conditions
        if self.latent_is_dynamic:
            logpdf = self.latent_dist(sum_out).log_prob(z)
        # Case static latent space
        else:
            logpdf = self.latent_dist.log_prob(z)

        # Compute the total posterior loss
        posterior_loss = tf.reduce_mean(-logpdf - log_det_J) + sum_loss

        # SELF-CONSISTENCY LOSS
        if tf.greater(lambda_, 0.0):
            # x has shape (n_datasets, data_dim)
            # indices = tf.stop_gradient(tf.range(tf.shape(self.real_data)[0]))
            # prior_draw = tf.random.normal((64, 7), 1, 2)
            # x = tf.stop_gradient(self.simulator(prior_draw))
            if callable(self.real_data):
                x = self.real_data()
            else:
                x = self.real_data

            n_datasets = tf.shape(x)[0]

            # z shape: n_consistency_samples, n_datasets, n_params
            z = self.latent_dist.sample(
                (self.n_consistency_samples, n_datasets), to_numpy=False
            )

            # add an n_consistency_samples dimension as first (0th) index to x
            # conditions shape: n_consistency_samples, n_datasets, summary_dim
            data_summary = self.summary_net(x)
            data_summary = tf.expand_dims(data_summary, axis=0)
            conditions = tf.tile(data_summary, [self.n_consistency_samples, 1, 1])

            # x_repeated shape: n_consistency_samples, n_datasets, data_dim
            x_reshaped = tf.expand_dims(x, axis=0)
            x_repeated = tf.tile(x_reshaped, [self.n_consistency_samples, 1, 1])

            # theta shape: n_consistency_samples, n_datasets, n_params
            theta = tf.stop_gradient(
                self.inference_net.inverse(z, conditions, training=False)
            )

            # log_prior is log(p(theta)) with shape n_consistency_samples, n_datasets
            log_prior = self.prior.log_prob(theta)

            # log_lik is log(p(y | theta)) with shape n_consistency_samples, n_datasets
            log_lik = tf.stop_gradient(self.simulator.log_prob(theta, x_repeated))

            # log_post is log(p(theta | y)) with shape n_consistency_samples, n_datasets
            sc_input_dict = {
                "parameters": tf.reshape(theta, (-1, tf.shape(theta)[-1])),
                "summary_conditions": tf.reshape(
                    x_repeated, (-1, tf.shape(x_repeated)[-1])
                ),
            }
            log_post = self.log_posterior(sc_input_dict, to_numpy=False)
            log_post = tf.reshape(log_post, tf.shape(theta)[:-1])

            # marginal likelihood p(y) = p(theta) * p(y | theta) / p(theta | y)
            # shape: n_consistency_samples, n_datasets
            log_ml = log_prior + log_lik - log_post

            # shape: n_datasets
            log_ml_var = tf.math.reduce_variance(log_ml, axis=-2)

            # shape: scalar
            sc_loss = tf.math.reduce_mean(log_ml_var, axis=-1)
        else:
            sc_loss = tf.constant(0.0)

        return {
            "Post.Loss": posterior_loss,
            "SC.Loss": tf.multiply(lambda_, sc_loss),
        }
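Note: the self-consistency term above rests on the identity p(y) = p(theta) * p(y | theta) / p(theta | y), which holds for every theta, so the log marginal likelihood estimated this way has zero variance across theta draws exactly when the amortized posterior is correct. A minimal, self-contained sketch of the same variance penalty on a toy conjugate-normal model (all names below are illustrative and not part of this commit):

import tensorflow as tf
import tensorflow_probability as tfp

tfd = tfp.distributions


def toy_sc_loss(y, n_consistency_samples=32):
    # Toy conjugate model: theta ~ N(0, 1), y | theta ~ N(theta, 1).
    # The exact posterior for a single observation y is N(y / 2, sqrt(1 / 2)).
    posterior = tfd.Normal(loc=y / 2.0, scale=tf.sqrt(0.5))  # stand-in for the amortizer

    theta = posterior.sample(n_consistency_samples)  # shape: (S, n_datasets)

    log_prior = tfd.Normal(0.0, 1.0).log_prob(theta)
    log_lik = tfd.Normal(theta, 1.0).log_prob(y)  # broadcasts y over the S draws
    log_post = posterior.log_prob(theta)

    # log p(y) = log p(theta) + log p(y | theta) - log p(theta | y), per draw
    log_ml = log_prior + log_lik - log_post

    # variance over theta draws, averaged over datasets; ~0 for the exact posterior
    return tf.reduce_mean(tf.math.reduce_variance(log_ml, axis=0))


print(toy_sc_loss(tf.constant([0.3, -1.2, 2.0])))  # ~0 up to float32 round-off

Replacing the exact posterior with a misspecified one (e.g. a wrong scale) makes the returned variance strictly positive, which is the signal the SC.Loss term feeds back into training.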
Lines changed: 82 additions & 0 deletions
@@ -0,0 +1,82 @@
import tensorflow as tf
import tensorflow_probability as tfp
from bayesflow.simulation import GenerativeModel
from .ode import HodgkinHuxleyODE


def get_generative_model():
    prior = PriorWithLogProb()
    simulator = SimulatorWithLogProb()

    model = GenerativeModel(
        prior=prior,
        simulator=simulator,
        prior_is_batched=False,
        simulator_is_batched=True,
    )

    return model


class SimulatorWithLogProb:
    """Student-t observation model around the Hodgkin-Huxley ODE solution."""

    def __init__(self):
        self.ode = HodgkinHuxleyODE()

    def __call__(self, z):
        theta = z_to_theta(z)
        result = self.ode.solve_ode(theta)
        dist = tfp.distributions.StudentT(loc=result, scale=0.1, df=10)

        y = dist.sample()

        return y

    def log_prob(self, z, x):
        theta = z_to_theta(z)
        y = self.ode.solve_ode(theta)
        dist = tfp.distributions.StudentT(loc=y, scale=0.1, df=10)

        x_flat = tf.reshape(x, (-1, tf.shape(x)[-1]))
        pointwise_log_prob = dist.log_prob(x_flat)
        # average (rather than sum) over time points, then restore the batch shape
        log_prob = tf.reduce_mean(pointwise_log_prob, axis=-1)
        log_prob = tf.reshape(log_prob, tf.shape(x)[:-1])

        return log_prob


class PriorWithLogProb:
    """Standard normal prior over the 7 unconstrained parameters z."""

    def __call__(self):
        z = tfp.distributions.Normal(loc=0, scale=1).sample(7)

        return z

    def log_prob(self, z):
        return tf.reduce_sum(
            tfp.distributions.Normal(loc=0, scale=1).log_prob(z), axis=-1
        )


def theta_to_z(theta):
    # map native Hodgkin-Huxley parameters to a roughly standardized unconstrained
    # space: log-scale for the first three parameters, affine for the remaining four
    z_1 = (tf.math.log(theta[..., 0]) - tf.math.log(110.0)) / 0.1
    z_2 = (tf.math.log(theta[..., 1]) - tf.math.log(36.0)) / 0.1
    z_3 = (tf.math.log(theta[..., 2]) - tf.math.log(0.2)) / 0.5
    z_4 = (theta[..., 3] - 1.0) / 0.05
    z_5 = (theta[..., 4] + 55.0) / 5.0
    z_6 = (theta[..., 5] - 50.0) / 5.0
    z_7 = (theta[..., 6] + 77.0) / 5.0

    return tf.stack([z_1, z_2, z_3, z_4, z_5, z_6, z_7], axis=-1)


def z_to_theta(z):
    # inverse of theta_to_z: unconstrained z back to native parameters
    theta_1 = tf.exp(tf.math.log(110.0) + 0.1 * z[..., 0])
    theta_2 = tf.exp(tf.math.log(36.0) + 0.1 * z[..., 1])
    theta_3 = tf.exp(tf.math.log(0.2) + 0.5 * z[..., 2])
    theta_4 = z[..., 3] * 0.05 + 1.0
    theta_5 = z[..., 4] * 5.0 - 55.0
    theta_6 = z[..., 5] * 5.0 + 50.0
    theta_7 = z[..., 6] * 5.0 - 77.0

    return tf.stack(
        [theta_1, theta_2, theta_3, theta_4, theta_5, theta_6, theta_7], axis=-1
    )
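Note: theta_to_z and z_to_theta are written to be exact inverses of each other. A quick round-trip check, assuming both helpers are imported from the module above:

import tensorflow as tf

z = tf.random.normal((4, 7))
theta = z_to_theta(z)
z_back = theta_to_z(theta)

# expect agreement up to float32 round-off, e.g. ~1e-6
print(tf.reduce_max(tf.abs(z - z_back)).numpy())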
Lines changed: 17 additions & 0 deletions
@@ -0,0 +1,17 @@
import numpy as np
import tensorflow as tf


# mean(abs(mean(y_pred) - y)): posterior-predictive mean absolute bias per dataset
def mean_absolute_bias(trainer, y, n_samples=5000):
    # draw posterior samples conditioned on each observed series in y
    posterior_draws = trainer.amortizer.sample(
        {"summary_conditions": y}, n_samples=n_samples
    )

    # push the draws through the simulator; the last axis holds the 200 time points
    y_pred = trainer.generative_model.simulator(posterior_draws)["sim_data"]
    y_pred = tf.reshape(y_pred, [*posterior_draws.shape[:-1], 200]).numpy()

    # compare the posterior-predictive mean against the observed series
    absolute_bias_i = np.abs(np.mean(y_pred, axis=-2) - y)
    mean_absolute_bias = np.mean(absolute_bias_i, axis=-1)

    return mean_absolute_bias
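Note: a hypothetical call, assuming a trainer built by the training script below and observed series y with 200 time points per dataset; the variable names here do not appear in the commit itself:

# y: observed data of shape (n_datasets, 200); trainer: a fitted bf.trainers.Trainer
bias = mean_absolute_bias(trainer, y, n_samples=1000)
print(bias.shape)  # (n_datasets,) - one mean absolute bias per observed series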
Lines changed: 135 additions & 0 deletions
@@ -0,0 +1,135 @@
import bayesflow as bf
import tensorflow as tf
import pickle
from pathlib import Path
from .generative_model import get_generative_model
from bayesflow.amortizers import AmortizedPosterior
from .amortized_posterior_sc import AmortizedPosteriorSC


def get_real_data():
    file_path = Path(__file__).parents[0] / "data" / "real_data.pkl"

    if not file_path.exists():
        # pseudo "real" data simulated from broadened parameter draws
        # (scale 2 instead of the prior's scale 1), cached to disk on first call
        model = get_generative_model()
        prior = tf.random.normal((1024, 7), 0, 2)
        data = model.simulator(prior)["sim_data"]
        real_data = data  # + tf.random.uniform(data.shape, minval=-2.0, maxval=2.0)

        with open(file_path, "wb") as file:
            pickle.dump(real_data, file)

    with open(file_path, "rb") as file:
        real_data = pickle.load(file)

    return real_data


def get_real_data_subset(n=32):
    x = get_real_data()
    indices = tf.random.shuffle(tf.range(tf.shape(x)[0]))[:n]

    subset = tf.gather(x, indices)

    return subset


def get_training_data():
    file_path = Path(__file__).parents[0] / "data" / "training_data.pkl"

    if not file_path.exists():
        model = get_generative_model()
        forward_dict = model(2**15)

        with open(file_path, "wb") as file:
            pickle.dump(forward_dict, file)

    with open(file_path, "rb") as file:
        forward_dict = pickle.load(file)

    return forward_dict


def get_summary_network():
    return tf.keras.Sequential(
        [
            tf.keras.layers.Lambda(lambda x: tf.expand_dims(x, -1)),
            tf.keras.layers.LSTM(100),
            tf.keras.layers.Dense(400, activation="relu"),
            tf.keras.layers.Dense(200, activation="relu"),
            tf.keras.layers.Dense(100, activation="relu"),
            tf.keras.layers.Dense(50, activation="relu"),
        ]
    )


def get_inference_network():
    return bf.networks.InvertibleNetwork(
        num_params=7,
        num_coupling_layers=10,
        coupling_design="spline",
        coupling_settings={
            "dense_args": {"units": 256},
            "kernel_regularizer": tf.keras.regularizers.l2(1e-3),
        },
    )


def configurator(forward_dict):
    input_dict = {}

    # pass parameters and raw series through unchanged; the summary network's
    # Lambda layer adds the trailing channel dimension
    input_dict["parameters"] = forward_dict["prior_draws"]
    input_dict["summary_conditions"] = forward_dict["sim_data"]

    return input_dict


def get_amortizer():
    model = get_generative_model()
    summary_net = get_summary_network()
    inference_net = get_inference_network()

    # underlying PriorWithLogProb / SimulatorWithLogProb objects,
    # needed for their log_prob methods in the self-consistency loss
    simulator = model.simulator.simulator
    prior = model.prior.prior

    amortizer = AmortizedPosteriorSC(
        prior=prior,
        simulator=simulator,
        real_data=get_real_data_subset,
        inference_net=inference_net,
        summary_net=summary_net,
        n_consistency_samples=8,
    )

    return amortizer


def get_trainer(**kwargs):
    generative_model = get_generative_model()
    amortizer = get_amortizer()

    trainer = bf.trainers.Trainer(
        amortizer=amortizer,
        generative_model=generative_model,
        configurator=configurator,
        **kwargs,
    )

    return trainer


def get_trainer_no_sc(**kwargs):
    generative_model = get_generative_model()
    amortizer = AmortizedPosterior(
        inference_net=get_inference_network(), summary_net=get_summary_network()
    )

    trainer = bf.trainers.Trainer(
        amortizer=amortizer,
        generative_model=generative_model,
        configurator=configurator,
        **kwargs,
    )

    return trainer
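Note: a hedged end-to-end sketch of how these helpers could be combined, assuming the module layout implied by the imports above and BayesFlow's offline training API (Trainer.train_offline); the checkpoint paths, epoch count, and batch size are placeholders, not settings from this commit:

# build the self-consistency trainer and its plain-posterior baseline
trainer = get_trainer(checkpoint_path="checkpoints/hh_sc")
trainer_no_sc = get_trainer_no_sc(checkpoint_path="checkpoints/hh_baseline")

# pre-simulated training data (cached to data/training_data.pkl on first call)
offline_data = get_training_data()

# train both amortizers on the same simulations
history = trainer.train_offline(offline_data, epochs=50, batch_size=64)
history_no_sc = trainer_no_sc.train_offline(offline_data, epochs=50, batch_size=64)

# amortized posterior draws for a held-out subset of the "real" data
y_obs = get_real_data_subset(n=16)
samples = trainer.amortizer.sample({"summary_conditions": y_obs}, n_samples=1000)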
