 from keras import ops
 
 from bayesflow.types import Tensor
-from bayesflow.utils import find_network, layer_kwargs, weighted_mean, expand_right_as
+from bayesflow.utils import find_network, layer_kwargs, weighted_mean, expand_right_as, logging
 from bayesflow.utils.serialization import deserialize, serializable, serialize
 
 from ..inference_network import InferenceNetwork
@@ -40,11 +40,11 @@ def __init__(
         self,
         total_steps: int | float,
         subnet: str | keras.Layer = "time_mlp",
-        max_time: int | float = 200,
+        max_time: int | float = 80,
         sigma2: float = 1.0,
         eps: float = 0.001,
         s0: int | float = 10,
-        s1: int | float = 50,
+        s1: int | float = 150,
         subnet_kwargs: dict[str, any] = None,
         **kwargs,
     ):
@@ -60,15 +60,15 @@ def __init__(
             If a string is provided, it should be a registered name (e.g., "time_mlp").
             If a type or keras.Layer is provided, it will be directly instantiated
             with the given ``subnet_kwargs``. Any subnet must accept a tuple of tensors (target, time, conditions).
-        max_time : int or float, optional, default: 200.0
-            The maximum time of the diffusion
+        max_time : int or float, optional, default: 80
+            The maximum time of the diffusion, equivalent to the maximum noise level (x_1 = z * max_time).
         sigma2 : float or Tensor of dimension (input_dim, 1), optional, default: 1.0
             Controls the shape of the skip-function
         eps : float, optional, default: 0.001
             The minimum time
         s0 : int or float, optional, default: 10
             Initial number of discretization steps
-        s1 : int or float, optional, default: 50
+        s1 : int or float, optional, default: 150
             Final number of discretization steps
         subnet_kwargs: dict[str, any], optional
             Keyword arguments passed to the subnet constructor or used to update the default MLP settings.
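
For orientation, a hypothetical construction call using the new defaults could look like the sketch below. It assumes this file defines the ConsistencyModel exported from bayesflow.networks and that "time_mlp" is a registered subnet name; the import path, the total_steps value, and the surrounding training workflow are not part of this diff.

from bayesflow.networks import ConsistencyModel

cm = ConsistencyModel(
    total_steps=10_000,  # illustrative; must be >= s0, see the new check in __init__ below
    subnet="time_mlp",
    max_time=80,
    s0=10,
    s1=150,
)
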
@@ -90,17 +90,26 @@ def __init__(
         self.sigma = ops.sqrt(sigma2)
         self.eps = eps
         self.max_time = max_time
-        self.c_huber = None
-        self.c_huber2 = None
+        self.rho = float(kwargs.get("rho", 7.0))
+        self.p_mean = float(kwargs.get("p_mean", -1.1))
+        self.p_std = float(kwargs.get("p_std", 2.0))
 
         self.s0 = float(s0)
         self.s1 = float(s1)
 
+        if self.total_steps < self.s0:
+            raise ValueError(f"total_steps={self.total_steps} must be greater than or equal to s0={self.s0}.")
+
         # create variable that works with JIT compilation
         self.current_step = self.add_weight(name="current_step", initializer="zeros", trainable=False, dtype="int")
         self.current_step.assign(0)
 
         self.seed_generator = keras.random.SeedGenerator()
+        self.discretized_times = None
+        self.discretization_map = None
+        self.c_huber = None
+        self.c_huber2 = None
+        self.unique_n = None
 
     @property
     def student(self):
@@ -122,34 +131,36 @@ def get_config(self):
             "eps": self.eps,
             "s0": self.s0,
             "s1": self.s1,
+            "rho": self.rho,
+            "p_mean": self.p_mean,
+            "p_std": self.p_std,
             # we do not need to store subnet_kwargs
         }
 
         return base_config | serialize(config)
 
     def _schedule_discretization(self, step) -> float:
-        """Schedule function for adjusting the discretization level `N` during
+        """Schedule function for adjusting the discretization level `N(k)` during
         the course of training.
 
         Implements the function N(k) from [2], Section 3.4.
         """
-
-        k_ = ops.floor(self.total_steps / (ops.log(self.s1 / self.s0) / ops.log(2.0) + 1.0))
+        k_ = ops.floor(self.total_steps / (ops.log(ops.floor(self.s1 / self.s0)) / ops.log(2.0) + 1.0))
         out = ops.minimum(self.s0 * ops.power(2.0, ops.floor(step / k_)), self.s1) + 1.0
         return out
 
-    def _discretize_time(self, num_steps, rho=7.0):
+    def _discretize_time(self, n_k: int) -> Tensor:
         """Function for obtaining the discretized time according to [2],
         Section 2, bottom of page 2.
         """
-
-        N = num_steps + 1
-        indices = ops.arange(1, N + 1, dtype="float32")
-        one_over_rho = 1.0 / rho
+        indices = ops.arange(1, n_k + 1, dtype="float32")
+        one_over_rho = 1.0 / self.rho
         discretized_time = (
             self.eps ** one_over_rho
-            + (indices - 1.0) / (ops.cast(N, "float32") - 1.0) * (self.max_time ** one_over_rho - self.eps ** one_over_rho)
-        ) ** rho
+            + (indices - 1.0)
+            / (ops.cast(n_k, "float32") - 1.0)
+            * (self.max_time ** one_over_rho - self.eps ** one_over_rho)
+        ) ** self.rho
         return discretized_time
 
     def build(self, xz_shape, conditions_shape=None):
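
As a quick check on the constants above, here is a hedged, standalone NumPy sketch of the two schedules changed in this hunk. The helper names and total_steps=10_000 are illustrative stand-ins for the class methods, not library API; defaults mirror the diff (s0=10, s1=150, rho=7, eps=1e-3, max_time=80).

import numpy as np

def schedule_discretization(step, total_steps, s0=10.0, s1=150.0):
    # N(k) from [2], Section 3.4: the level doubles every k_ steps and is capped at s1
    k_ = np.floor(total_steps / (np.log2(np.floor(s1 / s0)) + 1.0))
    return min(s0 * 2.0 ** np.floor(step / k_), s1) + 1.0

def discretize_time(n_k, eps=1e-3, max_time=80.0, rho=7.0):
    # Karras-style grid from [2]: t_1 = eps and t_{n_k} = max_time
    i = np.arange(1, n_k + 1, dtype="float64")
    return (eps ** (1 / rho) + (i - 1) / (n_k - 1) * (max_time ** (1 / rho) - eps ** (1 / rho))) ** rho

levels = sorted({int(schedule_discretization(k, 10_000)) for k in range(10_000)})
print(levels)                          # [11, 21, 41, 81, 151], and 151 == s1 + 1 as build() asserts below
print(discretize_time(11)[[0, -1]])    # approximately [0.001, 80.0]
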
@@ -183,21 +194,20 @@ def build(self, xz_shape, conditions_shape=None):
 
         # First, we calculate all unique numbers of discretization steps n
         # in a loop, as self.total_steps might be large
-        self.max_n = int(self._schedule_discretization(self.total_steps))
-
-        if self.max_n != self.s1 + 1:
+        max_n = int(self._schedule_discretization(self.total_steps))
+        if max_n != self.s1 + 1:
             raise ValueError("The maximum number of discretization steps must be equal to s1 + 1.")
 
         unique_n = set()
         for step in range(int(self.total_steps)):
             unique_n.add(int(self._schedule_discretization(step)))
-        unique_n = sorted(list(unique_n))
+        self.unique_n = sorted(list(unique_n))
 
         # Next, we calculate the discretized times for each n
         # and establish a mapping between n and the position i of the
         # discretized times in the vector
-        discretized_times = np.zeros((len(unique_n), self.max_n + 1))
-        discretization_map = np.zeros((self.max_n + 1,), dtype=np.int32)
+        discretized_times = np.zeros((len(unique_n), max_n + 1))
+        discretization_map = np.zeros((max_n + 1,), dtype=np.int32)
         for i, n in enumerate(unique_n):
             disc = ops.convert_to_numpy(self._discretize_time(n))
             discretized_times[i, : len(disc)] = disc
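
A minimal self-contained sketch of the lookup structure assembled here: row i of discretized_times holds the time grid for level n, and discretization_map[n] = i lets JIT-compiled code recover that row from n. The map assignment and the placeholder grid are assumptions, since those lines fall outside the hunk.

import numpy as np

unique_n = [11, 21, 41, 81, 151]   # levels visited by the schedule in an illustrative run
max_n = unique_n[-1]
discretized_times = np.zeros((len(unique_n), max_n + 1))
discretization_map = np.zeros((max_n + 1,), dtype=np.int32)
for i, n in enumerate(unique_n):
    discretized_times[i, :n] = np.linspace(1e-3, 80.0, n)  # placeholder for self._discretize_time(n)
    discretization_map[n] = i
# discretized_times[discretization_map[41], :41] then recovers the grid for N = 41
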
@@ -232,15 +242,20 @@ def _inverse(self, z: Tensor, conditions: Tensor = None, training: bool = False,
         training : bool, optional, default: True
             Whether internal layers (e.g., dropout) should behave in train or inference mode.
         **kwargs : dict, optional, default: {}
-            Additional keyword arguments. Include `steps` (default: 10) to
+            Additional keyword arguments. Include `steps` (default: s0 + 1) to
             adjust the number of sampling steps.
 
         Returns
         -------
         x : Tensor
             The approximate samples
         """
-        steps = kwargs.get("steps", 10)
+        steps = int(kwargs.get("steps", self.s0 + 1))
+        if steps not in self.unique_n:
+            logging.warning(
+                "The requested number of sampling steps is not among the discretization levels used during training. "
+                "This might lead to suboptimal sample quality."
+            )
         x = keras.ops.copy(z) * self.max_time
         discretized_time = keras.ops.flip(self._discretize_time(steps), axis=-1)
         t = keras.ops.full((*keras.ops.shape(x)[:-1], 1), discretized_time[0], dtype=x.dtype)
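
The sampling loop that consumes discretized_time is outside this hunk. For orientation only, multistep consistency sampling (Song et al., 2023, Algorithm 1) proceeds roughly as below; f stands in for consistency_function, and the re-noising step is an assumption about the unshown code.

import numpy as np

def multistep_sample(f, z, times, eps=1e-3, seed=0):
    # times run from max_time down to eps, as produced by the flip above
    rng = np.random.default_rng(seed)
    x = f(z * times[0], times[0])
    for t in times[1:]:
        noise = rng.standard_normal(x.shape)
        x_t = x + np.sqrt(max(t ** 2 - eps ** 2, 0.0)) * noise  # re-noise to level t
        x = f(x_t, t)
    return x
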
@@ -268,7 +283,7 @@ def consistency_function(self, x: Tensor, t: Tensor, conditions: Tensor = None,
         training : bool, optional, default: True
             Whether internal layers (e.g., dropout) should behave in train or inference mode.
         """
-        subnet_out = self.subnet((x, t, conditions), training=training)
+        subnet_out = self.subnet((x, t / self.max_time, conditions), training=training)
         f = self.output_projector(subnet_out)
 
         # Compute skip and out parts (vectorized, since self.sigma2 is of shape (1, input_dim)
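
The skip/out weighting referenced in the comment is not shown in this hunk. As an assumption about what follows, the standard consistency-model parameterization (Song et al., 2023) is sketched below; sigma2 may broadcast per dimension, matching the (1, input_dim) shape mentioned above.

import numpy as np

def parameterize(f, x, t, eps=1e-3, sigma2=1.0):
    # c_skip -> 1 and c_out -> 0 as t -> eps, enforcing the boundary condition f_theta(x, eps) = x
    c_skip = sigma2 / ((t - eps) ** 2 + sigma2)
    c_out = np.sqrt(sigma2) * (t - eps) / np.sqrt(sigma2 + t ** 2)
    return c_skip * x + c_out * f
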
@@ -298,12 +313,10 @@ def compute_metrics(
 
         # Randomly sample t_n and t_[n+1] and reshape to (batch_size, 1)
         # adapted noise schedule from [2], Section 3.5
-        p_mean = -1.1
-        p_std = 2.0
         p = ops.where(
             discretized_time[1:] > 0.0,
-            ops.erf((ops.log(discretized_time[1:]) - p_mean) / (ops.sqrt(2.0) * p_std))
-            - ops.erf((ops.log(discretized_time[:-1]) - p_mean) / (ops.sqrt(2.0) * p_std)),
+            ops.erf((ops.log(discretized_time[1:]) - self.p_mean) / (ops.sqrt(2.0) * self.p_std))
+            - ops.erf((ops.log(discretized_time[:-1]) - self.p_mean) / (ops.sqrt(2.0) * self.p_std)),
             0.0,
         )
 
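
Each weight above is, up to a constant factor, the probability mass a log-normal with parameters (p_mean, p_std) places on the grid interval [t_n, t_{n+1}]; indices are then sampled according to these weights. A standalone sketch with an illustrative grid:

import math

def interval_weights(times, p_mean=-1.1, p_std=2.0):
    cdf = [0.5 * (1.0 + math.erf((math.log(t) - p_mean) / (math.sqrt(2.0) * p_std))) for t in times]
    return [hi - lo for lo, hi in zip(cdf[:-1], cdf[1:])]

print(interval_weights([0.001, 0.1, 1.0, 10.0, 80.0]))  # most mass falls on intervals near t ~ exp(p_mean)
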
@@ -316,8 +329,7 @@ def compute_metrics(
         noise = keras.random.normal(keras.ops.shape(x), dtype=keras.ops.dtype(x), seed=self.seed_generator)
 
         teacher_out = self._forward_train(x, noise, t1, conditions=conditions, training=stage == "training")
-        # difference between teacher and student: different time,
-        # and no gradient for the teacher
+        # difference between teacher and student: different time, and no gradient for the teacher
         teacher_out = ops.stop_gradient(teacher_out)
         student_out = self._forward_train(x, noise, t2, conditions=conditions, training=stage == "training")
 
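The loss computed from teacher_out and student_out lies outside this diff. For orientation, the improved consistency-training objective of [2], Section 3.3 uses a Pseudo-Huber distance with c = 0.00054 * sqrt(dim), which is presumably what the c_huber and c_huber2 attributes (initialized to None above) later hold; the sketch below is an assumption, not the class's loss code.

import numpy as np

def pseudo_huber(student, teacher, c):
    # smooth L2-like distance that grows linearly for large errors
    return np.sqrt(np.sum((student - teacher) ** 2, axis=-1) + c ** 2) - c

dim = 4
c_huber = 0.00054 * np.sqrt(dim)
loss = pseudo_huber(np.zeros((8, dim)), 0.01 * np.ones((8, dim)), c_huber).mean()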