@@ -59,24 +59,30 @@ def brownian_motion_prior_fn(num_timesteps,
59
59
name = 'x_{}' .format (t ))
60
60
61
61
62
- def brownian_motion_unknown_scales_prior_fn (num_timesteps , use_markov_chain ):
62
+ def brownian_motion_unknown_scales_prior_fn (
63
+ num_timesteps , use_markov_chain , dtype
64
+ ):
63
65
"""Generative process for the Brownian Motion model with unknown scales."""
64
- innovation_noise_scale = yield Root (tfd .LogNormal (
65
- 0. , 2. , name = 'innovation_noise_scale' ))
66
- _ = yield Root (tfd .LogNormal (0. , 2. , name = 'observation_noise_scale' ))
66
+ zero = tf .zeros ([], dtype )
67
+ innovation_noise_scale = yield Root (
68
+ tfd .LogNormal (zero , 2.0 , name = 'innovation_noise_scale' )
69
+ )
70
+ _ = yield Root (tfd .LogNormal (zero , 2.0 , name = 'observation_noise_scale' ))
67
71
if use_markov_chain :
68
72
yield brownian_motion_as_markov_chain (
69
73
num_timesteps = num_timesteps ,
70
- innovation_noise_scale = innovation_noise_scale )
74
+ innovation_noise_scale = innovation_noise_scale ,
75
+ )
71
76
else :
72
77
yield from brownian_motion_prior_fn (
73
- num_timesteps ,
74
- innovation_noise_scale = innovation_noise_scale )
78
+ num_timesteps , innovation_noise_scale = innovation_noise_scale
79
+ )
75
80
76
81
77
82
def brownian_motion_log_likelihood_fn (values ,
78
83
observed_locs ,
79
84
use_markov_chain ,
85
+ dtype ,
80
86
observation_noise_scale = None ):
81
87
"""Likelihood of observed data under the Brownian Motion model."""
82
88
if observation_noise_scale is None :
@@ -86,7 +92,12 @@ def brownian_motion_log_likelihood_fn(values,
86
92
latents = values if use_markov_chain else tf .stack (values , axis = - 1 )
87
93
88
94
observation_noise_scale = tf .convert_to_tensor (
89
- observation_noise_scale , name = 'observation_noise_scale' )
95
+ observation_noise_scale , dtype = dtype , name = 'observation_noise_scale' )
96
+ observed_locs = tf .cast (
97
+ observed_locs ,
98
+ dtype = dtype ,
99
+ name = 'observed_locs' ,
100
+ )
90
101
is_observed = ~ tf .math .is_nan (observed_locs )
91
102
lps = tfd .Normal (
92
103
loc = latents , scale = observation_noise_scale [..., tf .newaxis ]).log_prob (
@@ -117,6 +128,7 @@ def __init__(self,
117
128
innovation_noise_scale ,
118
129
observation_noise_scale ,
119
130
use_markov_chain = False ,
131
+ dtype = tf .float32 ,
120
132
name = 'brownian_motion' ,
121
133
pretty_name = 'Brownian Motion' ):
122
134
"""Construct the Brownian Motion model.
@@ -130,11 +142,18 @@ def __init__(self,
130
142
`MarkovChain` distribution in place of separate random variables for
131
143
each time step. The default of `False` is for backwards compatibility;
132
144
setting this to `True` should significantly improve performance.
145
+ dtype: Dtype to use for floating point quantities.
133
146
name: Python `str` name prefixed to Ops created by this class.
134
147
pretty_name: A Python `str`. The pretty name of this model.
135
148
"""
136
149
with tf .name_scope (name ):
137
150
num_timesteps = observed_locs .shape [0 ]
151
+ innovation_noise_scale = tf .convert_to_tensor (
152
+ innovation_noise_scale ,
153
+ dtype = dtype ,
154
+ name = 'innovation_noise_scale' ,
155
+ )
156
+
138
157
if use_markov_chain :
139
158
self ._prior_dist = brownian_motion_as_markov_chain (
140
159
num_timesteps = num_timesteps ,
@@ -150,7 +169,8 @@ def __init__(self,
150
169
brownian_motion_log_likelihood_fn ,
151
170
observation_noise_scale = observation_noise_scale ,
152
171
observed_locs = observed_locs ,
153
- use_markov_chain = use_markov_chain )
172
+ use_markov_chain = use_markov_chain ,
173
+ dtype = dtype )
154
174
155
175
def _ext_identity (params ):
156
176
return tf .stack (params , axis = - 1 )
@@ -164,6 +184,7 @@ def _ext_identity_markov_chain(params):
164
184
fn = (_ext_identity_markov_chain
165
185
if use_markov_chain else _ext_identity ),
166
186
pretty_name = 'Identity' ,
187
+ dtype = dtype ,
167
188
)
168
189
}
169
190
@@ -193,12 +214,13 @@ class BrownianMotionMissingMiddleObservations(BrownianMotion):
193
214
194
215
GROUND_TRUTH_MODULE = brownian_motion_missing_middle_observations
195
216
196
- def __init__ (self , use_markov_chain = False ):
217
+ def __init__ (self , use_markov_chain = False , dtype = tf . float32 ):
197
218
dataset = data .brownian_motion_missing_middle_observations ()
198
219
super (BrownianMotionMissingMiddleObservations , self ).__init__ (
199
220
name = 'brownian_motion_missing_middle_observations' ,
200
221
pretty_name = 'Brownian Motion Missing Middle Observations' ,
201
222
use_markov_chain = use_markov_chain ,
223
+ dtype = dtype ,
202
224
** dataset )
203
225
204
226
@@ -226,6 +248,7 @@ class BrownianMotionUnknownScales(bayesian_model.BayesianModel):
226
248
def __init__ (self ,
227
249
observed_locs ,
228
250
use_markov_chain = False ,
251
+ dtype = tf .float32 ,
229
252
name = 'brownian_motion_unknown_scales' ,
230
253
pretty_name = 'Brownian Motion with Unknown Scales' ):
231
254
"""Construct the Brownian Motion model with unknown scales.
@@ -238,6 +261,7 @@ def __init__(self,
238
261
each time step. The default of `False` is for backwards compatibility;
239
262
setting this to `True` should significantly improve performance.
240
263
Default value: `False`.
264
+ dtype: Dtype to use for floating point quantities.
241
265
name: Python `str` name prefixed to Ops created by this class.
242
266
pretty_name: A Python `str`. The pretty name of this model.
243
267
"""
@@ -247,12 +271,14 @@ def __init__(self,
247
271
functools .partial (
248
272
brownian_motion_unknown_scales_prior_fn ,
249
273
use_markov_chain = use_markov_chain ,
250
- num_timesteps = num_timesteps ))
274
+ num_timesteps = num_timesteps ,
275
+ dtype = dtype ))
251
276
252
277
self ._log_likelihood_fn = functools .partial (
253
278
brownian_motion_log_likelihood_fn ,
254
279
use_markov_chain = use_markov_chain ,
255
- observed_locs = observed_locs )
280
+ observed_locs = observed_locs ,
281
+ dtype = dtype )
256
282
257
283
def _ext_identity (params ):
258
284
return {'innovation_noise_scale' : params [0 ],
@@ -266,9 +292,9 @@ def _ext_identity(params):
266
292
model .Model .SampleTransformation (
267
293
fn = _ext_identity ,
268
294
pretty_name = 'Identity' ,
269
- dtype = {'innovation_noise_scale' : tf . float32 ,
270
- 'observation_noise_scale' : tf . float32 ,
271
- 'locs' : tf . float32 })
295
+ dtype = {'innovation_noise_scale' : dtype ,
296
+ 'observation_noise_scale' : dtype ,
297
+ 'locs' : dtype })
272
298
}
273
299
274
300
event_space_bijector = type (
@@ -300,12 +326,13 @@ class BrownianMotionUnknownScalesMissingMiddleObservations(
300
326
GROUND_TRUTH_MODULE = (
301
327
brownian_motion_unknown_scales_missing_middle_observations )
302
328
303
- def __init__ (self , use_markov_chain = False ):
329
+ def __init__ (self , use_markov_chain = False , dtype = tf . float32 ):
304
330
dataset = data .brownian_motion_missing_middle_observations ()
305
331
del dataset ['innovation_noise_scale' ]
306
332
del dataset ['observation_noise_scale' ]
307
333
super (BrownianMotionUnknownScalesMissingMiddleObservations , self ).__init__ (
308
334
name = 'brownian_motion_unknown_scales_missing_middle_observations' ,
309
335
pretty_name = 'Brownian Motion with Unknown Scales' ,
310
336
use_markov_chain = use_markov_chain ,
337
+ dtype = dtype ,
311
338
** dataset )
0 commit comments