
Commit a47479b

update default parameters for discrete consistency model (#629)
* update default parameters for discrete consistency model
* fix test
* norm time in consistency model
1 parent 5156f63 commit a47479b

3 files changed: +55 −43 lines changed


bayesflow/networks/consistency_models/consistency_model.py

Lines changed: 45 additions & 33 deletions
@@ -4,7 +4,7 @@
 from keras import ops

 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, layer_kwargs, weighted_mean, expand_right_as
+from bayesflow.utils import find_network, layer_kwargs, weighted_mean, expand_right_as, logging
 from bayesflow.utils.serialization import deserialize, serializable, serialize

 from ..inference_network import InferenceNetwork
@@ -40,11 +40,11 @@ def __init__(
         self,
         total_steps: int | float,
         subnet: str | keras.Layer = "time_mlp",
-        max_time: int | float = 200,
+        max_time: int | float = 80,
         sigma2: float = 1.0,
         eps: float = 0.001,
         s0: int | float = 10,
-        s1: int | float = 50,
+        s1: int | float = 150,
         subnet_kwargs: dict[str, any] = None,
         **kwargs,
     ):
@@ -60,15 +60,15 @@ def __init__(
             If a string is provided, it should be a registered name (e.g., "time_mlp").
             If a type or keras.Layer is provided, it will be directly instantiated
             with the given ``subnet_kwargs``. Any subnet must accept a tuple of tensors (target, time, conditions).
-        max_time : int or float, optional, default: 200.0
-            The maximum time of the diffusion
+        max_time : int or float, optional, default: 80
+            The maximum time of the diffusion, equivalent to the maximum noise level (x_1=z*max_time).
         sigma2 : float or Tensor of dimension (input_dim, 1), optional, default: 1.0
             Controls the shape of the skip-function
         eps : float, optional, default: 0.001
             The minimum time
         s0 : int or float, optional, default: 10
             Initial number of discretization steps
-        s1 : int or float, optional, default: 50
+        s1 : int or float, optional, default: 70
             Final number of discretization steps
         subnet_kwargs: dict[str, any], optional
             Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
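
For orientation, a minimal construction sketch (not part of the diff) under the updated defaults from the signature above; the step count and layer widths are purely illustrative, and the import path is the one used in the test fixture further below:

```python
# Hypothetical usage sketch; new defaults after this commit: max_time=80 and
# s1=150 (previously 200 and 50). total_steps must now be >= s0 (default 10),
# otherwise __init__ raises a ValueError.
from bayesflow.networks import ConsistencyModel

cm = ConsistencyModel(
    total_steps=10_000,                   # training steps the N(k) schedule is spread over
    subnet="time_mlp",                    # default subnet; receives (target, time, conditions)
    subnet_kwargs=dict(widths=(16, 16)),  # forwarded to the subnet constructor
)
```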
@@ -90,17 +90,26 @@ def __init__(
         self.sigma = ops.sqrt(sigma2)
         self.eps = eps
         self.max_time = max_time
-        self.c_huber = None
-        self.c_huber2 = None
+        self.rho = float(kwargs.get("rho", 7.0))
+        self.p_mean = float(kwargs.get("p_mean", -1.1))
+        self.p_std = float(kwargs.get("p_std", 2.0))

         self.s0 = float(s0)
         self.s1 = float(s1)

+        if self.total_steps < self.s0:
+            raise ValueError(f"total_steps={self.total_steps} must be greater than or equal to s0={self.s0}.")
+
         # create variable that works with JIT compilation
         self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int")
         self.current_step.assign(0)

         self.seed_generator = keras.random.SeedGenerator()
+        self.discretized_times = None
+        self.discretization_map = None
+        self.c_huber = None
+        self.c_huber2 = None
+        self.unique_n = None

     @property
     def student(self):
@@ -122,34 +131,36 @@ def get_config(self):
             "eps": self.eps,
             "s0": self.s0,
             "s1": self.s1,
+            "rho": self.rho,
+            "p_mean": self.p_mean,
+            "p_std": self.p_std,
             # we do not need to store subnet_kwargs
         }

         return base_config | serialize(config)

     def _schedule_discretization(self, step) -> float:
-        """Schedule function for adjusting the discretization level `N` during
+        """Schedule function for adjusting the discretization level `N(k)` during
         the course of training.

         Implements the function N(k) from [2], Section 3.4.
         """
-
-        k_ = ops.floor(self.total_steps / (ops.log(self.s1 / self.s0) / ops.log(2.0) + 1.0))
+        k_ = ops.floor(self.total_steps / (ops.log(ops.floor(self.s1 / self.s0)) / ops.log(2.0) + 1.0))
         out = ops.minimum(self.s0 * ops.power(2.0, ops.floor(step / k_)), self.s1) + 1.0
         return out

-    def _discretize_time(self, num_steps, rho=7.0):
+    def _discretize_time(self, n_k: int) -> Tensor:
         """Function for obtaining the discretized time according to [2],
         Section 2, bottom of page 2.
         """
-
-        N = num_steps + 1
-        indices = ops.arange(1, N + 1, dtype="float32")
-        one_over_rho = 1.0 / rho
+        indices = ops.arange(1, n_k + 1, dtype="float32")
+        one_over_rho = 1.0 / self.rho
         discretized_time = (
             self.eps**one_over_rho
-            + (indices - 1.0) / (ops.cast(N, "float32") - 1.0) * (self.max_time**one_over_rho - self.eps**one_over_rho)
-        ) ** rho
+            + (indices - 1.0)
+            / (ops.cast(n_k, "float32") - 1.0)
+            * (self.max_time**one_over_rho - self.eps**one_over_rho)
+        ) ** self.rho
         return discretized_time

     def build(self, xz_shape, conditions_shape=None):
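
For reference, a standalone NumPy sketch (not part of the diff) of the two schedules changed above, using the new defaults; note that `rho` is now read from `kwargs` instead of being a method argument:

```python
import numpy as np

# assumed defaults after this commit: s0=10, s1=150, eps=0.001, max_time=80, rho=7
s0, s1, total_steps, eps, max_time, rho = 10.0, 150.0, 10_000, 0.001, 80.0, 7.0

def schedule_discretization(step):
    # N(k) from [2], Section 3.4: doubles the step count from s0 towards s1
    k_ = np.floor(total_steps / (np.log(np.floor(s1 / s0)) / np.log(2.0) + 1.0))
    return min(s0 * 2.0 ** np.floor(step / k_), s1) + 1.0

def discretize_time(n_k):
    # rho-spaced times between eps and max_time, [2], Section 2
    indices = np.arange(1, n_k + 1, dtype="float32")
    one_over_rho = 1.0 / rho
    return (
        eps**one_over_rho
        + (indices - 1.0) / (n_k - 1.0) * (max_time**one_over_rho - eps**one_over_rho)
    ) ** rho

print(schedule_discretization(0))    # 11.0 == s0 + 1
print(discretize_time(11)[[0, -1]])  # [eps, max_time]
```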
@@ -183,21 +194,20 @@ def build(self, xz_shape, conditions_shape=None):

         # First, we calculate all unique numbers of discretization steps n
         # in a loop, as self.total_steps might be large
-        self.max_n = int(self._schedule_discretization(self.total_steps))
-
-        if self.max_n != self.s1 + 1:
+        max_n = int(self._schedule_discretization(self.total_steps))
+        if max_n != self.s1 + 1:
             raise ValueError("The maximum number of discretization steps must be equal to s1 + 1.")

         unique_n = set()
         for step in range(int(self.total_steps)):
             unique_n.add(int(self._schedule_discretization(step)))
-        unique_n = sorted(list(unique_n))
+        self.unique_n = sorted(list(unique_n))

         # Next, we calculate the discretized times for each n
         # and establish a mapping between n and the position i of the
         # discretized times in the vector
-        discretized_times = np.zeros((len(unique_n), self.max_n + 1))
-        discretization_map = np.zeros((self.max_n + 1,), dtype=np.int32)
+        discretized_times = np.zeros((len(unique_n), max_n + 1))
+        discretization_map = np.zeros((max_n + 1,), dtype=np.int32)
         for i, n in enumerate(unique_n):
             disc = ops.convert_to_numpy(self._discretize_time(n))
             discretized_times[i, : len(disc)] = disc
@@ -232,15 +242,20 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
         training : bool, optional, default: True
             Whether internal layers (e.g., dropout) should behave in train or inference mode.
         **kwargs : dict, optional, default: {}
-            Additional keyword arguments. Include `steps` (default: 10) to
+            Additional keyword arguments. Include `steps` (default: s0+1) to
             adjust the number of sampling steps.

         Returns
         -------
         x : Tensor
             The approximate samples
         """
-        steps = kwargs.get("steps", 10)
+        steps = int(kwargs.get("steps", self.s0 + 1))
+        if steps not in self.unique_n:
+            logging.warning(
+                "The number of discretization steps is not equal to the number of unique steps used during training. "
+                "This might lead to suboptimal sample quality."
+            )
         x = keras.ops.copy(z) * self.max_time
         discretized_time = keras.ops.flip(self._discretize_time(steps), axis=-1)
         t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype)
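
A sampling sketch (not part of the diff): the default number of multistep sampling steps is now `s0 + 1` instead of a fixed 10, and step counts never seen during training trigger the new warning. Here `cm` is assumed to be a built `ConsistencyModel`, and `_inverse` is called directly only for illustration:

```python
import keras

z = keras.random.normal((128, 2))              # latent draws; shapes are illustrative
x = cm._inverse(z, conditions=None, steps=11)  # 11 == s0 + 1 under the new defaults
```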
@@ -268,7 +283,7 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
         training : bool, optional, default: True
             Whether internal layers (e.g., dropout) should behave in train or inference mode.
         """
-        subnet_out = self.subnet((x, t, conditions), training=training)
+        subnet_out = self.subnet((x, t / self.max_time, conditions), training=training)
         f = self.output_projector(subnet_out)

         # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
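
The only functional change in this hunk is that the subnet now receives time rescaled by `max_time`. Below is a rough sketch (not part of the diff) of how such a parameterization typically looks, assuming the standard skip/out weighting of consistency models; the exact weights used by BayesFlow sit outside this hunk and may differ:

```python
from keras import ops

def consistency_sketch(subnet, output_projector, x, t, sigma2, eps, max_time, conditions=None):
    # the subnet sees time rescaled to roughly [0, 1] instead of [0, max_time]
    f = output_projector(subnet((x, t / max_time, conditions)))
    # hypothetical skip/out weights in the style of Song et al. (2023); not taken from the diff
    c_skip = sigma2 / ((t - eps) ** 2 + sigma2)
    c_out = ops.sqrt(sigma2) * (t - eps) / ops.sqrt(sigma2 + t**2)
    return c_skip * x + c_out * f
```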
@@ -298,12 +313,10 @@ def compute_metrics(

         # Randomly sample t_n and t_[n+1] and reshape to (batch_size, 1)
         # adapted noise schedule from [2], Section 3.5
-        p_mean = -1.1
-        p_std = 2.0
         p = ops.where(
             discretized_time[1:] > 0.0,
-            ops.erf((ops.log(discretized_time[1:]) - p_mean) / (ops.sqrt(2.0) * p_std))
-            - ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std)),
+            ops.erf((ops.log(discretized_time[1:]) - self.p_mean) / (ops.sqrt(2.0) * self.p_std))
+            - ops.erf((ops.log(discretized_time[:-1]) - self.p_mean) / (ops.sqrt(2.0) * self.p_std)),
             0.0,
         )

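The same selection probabilities in plain NumPy (not part of the diff): each adjacent pair (t_n, t_{n+1}) is drawn with probability proportional to the log-normal mass between the two times, now governed by the configurable `p_mean` and `p_std`:

```python
import numpy as np
from scipy.special import erf

eps, max_time, rho = 0.001, 80.0, 7.0  # assumed defaults after this commit
p_mean, p_std = -1.1, 2.0

i = np.arange(1, 12, dtype="float64")  # 11 discretized times, i.e. n_k = s0 + 1
t = (eps ** (1 / rho) + (i - 1) / 10.0 * (max_time ** (1 / rho) - eps ** (1 / rho))) ** rho

p = np.where(
    t[1:] > 0.0,
    erf((np.log(t[1:]) - p_mean) / (np.sqrt(2.0) * p_std))
    - erf((np.log(t[:-1]) - p_mean) / (np.sqrt(2.0) * p_std)),
    0.0,
)
p /= p.sum()  # probabilities for picking the pair (t_n, t_{n+1})
```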
@@ -316,8 +329,7 @@ def compute_metrics(
         noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)

         teacher_out = self._forward_train(x, noise, t1, conditions=conditions, training=stage == "training")
-        # difference between teacher and student: different time,
-        # and no gradient for the teacher
+        # difference between teacher and student: different time, and no gradient for the teacher
         teacher_out = ops.stop_gradient(teacher_out)
         student_out = self._forward_train(x, noise, t2, conditions=conditions, training=stage == "training")

bayesflow/networks/consistency_models/stable_consistency_model.py

Lines changed: 9 additions & 9 deletions
@@ -94,6 +94,9 @@ def __init__(
         )

         self.sigma = sigma
+        self.p_mean = float(kwargs.get("p_mean", -1.0))
+        self.p_std = float(kwargs.get("p_std", 1.6))
+        self.c = float(kwargs.get("c", 0.1))
         self.seed_generator = keras.random.SeedGenerator()

     @classmethod
@@ -107,6 +110,9 @@ def get_config(self):
         config = {
             "subnet": self.subnet,
             "sigma": self.sigma,
+            "p_mean": self.p_mean,
+            "p_std": self.p_std,
+            "c": self.c,
         }

         return base_config | serialize(config)
@@ -220,19 +226,13 @@ def compute_metrics(

         # $# Implements Algorithm 1 from [1]

-        # training parameters
-        p_mean = -1.0
-        p_std = 1.6
-
-        c = 0.1
-
         # generate noise vector
         z = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator) * self.sigma

         # sample time
         tau = (
-            keras.random.normal(keras.ops.shape(x)[:1], dtype=keras.ops.dtype(x), seed=self.seed_generator) * p_std
-            + p_mean
+            keras.random.normal(keras.ops.shape(x)[:1], dtype=keras.ops.dtype(x), seed=self.seed_generator) * self.p_std
+            + self.p_mean
         )
         t_ = ops.arctan(ops.exp(tau) / self.sigma)
         t = expand_right_as(t_, x)
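
A standalone sketch (not part of the diff) of the time sampling above: `tau` is Gaussian with the now-configurable `p_mean`/`p_std`, and `t = arctan(exp(tau) / sigma)` maps it into (0, pi/2):

```python
import numpy as np

sigma, p_mean, p_std = 1.0, -1.0, 1.6  # the sigma value here is purely illustrative
tau = np.random.normal(size=128) * p_std + p_mean
t = np.arctan(np.exp(tau) / sigma)     # shape (128,), values in (0, pi/2)
```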
@@ -269,7 +269,7 @@ def f_teacher(x, t):
         )

         # apply normalization to stabilize training
-        g = g / (ops.norm(g, axis=-1, keepdims=True) + c)
+        g = g / (ops.norm(g, axis=-1, keepdims=True) + self.c)

         # compute adaptive weights and calculate loss
         w = self.weight_fn_projector(self.weight_fn(expand_right_to(t_, 2)))

tests/test_workflows/conftest.py

Lines changed: 1 addition & 1 deletion
@@ -25,7 +25,7 @@ def inference_network(request):
     elif request.param == "consistency_model":
         from bayesflow.networks import ConsistencyModel

-        return ConsistencyModel(subnet_kwargs=dict(widths=(16, 16)), total_steps=4)
+        return ConsistencyModel(subnet_kwargs=dict(widths=(16, 16)), total_steps=10)


 @pytest.fixture(params=["time_series_transformer", "fusion_transformer", "time_series_network", "custom"])
