Skip to content

Commit 2297429

Browse files
SiegeLordEx and tensorflower-gardener
authored and committed
Inference Gym: Add 64 bit support to all targets.
This is done by adding a `dtype` argument to the initializer, which is then used to cast the inputs (if any) and as the dtype for de novo arrays (like `tf.zeros`).

This is a departure from the usual TensorFlow Probability style, where the dtype is inferred from the inputs, because:
- Inference Gym targets often have no inputs.
- Inference Gym is not designed to do deferred array materialization like TFP is.

We also depart from the typical JAX style of using the maximum precision available, because we want to enable testing 32-bit and 64-bit implementations in the same process.

NOTE: At this time, ground truth remains always NumPy arrays and always 64-bit.

Fixes #1993

PiperOrigin-RevId: 736231310
1 parent b688014 commit 2297429

35 files changed

+597
-269
lines changed

spinoffs/inference_gym/inference_gym/internal/test_util.py

+13-1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
import tensorflow.compat.v2 as tf
2424
import tensorflow_probability as tfp
2525

26+
from tensorflow_probability.python.internal import dtype_util
2627
from tensorflow_probability.python.internal import test_util
2728

2829
flags.DEFINE_bool('use_tfds', False, 'Whether to run tests that use TFDS.',
@@ -279,6 +280,7 @@ def validate_log_prob_and_transforms(
279280
self,
280281
model,
281282
sample_transformation_shapes,
283+
dtype=tf.float32,
282284
check_ground_truth_mean=False,
283285
check_ground_truth_mean_standard_error=False,
284286
check_ground_truth_standard_deviation=False,
@@ -295,6 +297,7 @@ def validate_log_prob_and_transforms(
295297
Args:
296298
model: The model to validate.
297299
sample_transformation_shapes: Shapes of the transformation outputs.
300+
dtype: The expected dtype of floating point quantities.
298301
check_ground_truth_mean: Whether to check the shape of the ground truth
299302
mean.
300303
check_ground_truth_mean_standard_error: Whether to check the shape of the
@@ -331,16 +334,25 @@ def _random_element(shape, dtype, default_event_space_bijector, seed):
331334

332335
self.assertAllFinite(log_prob)
333336
self.assertEqual((batch_size,), log_prob.shape)
337+
self.assertEqual(dtype, log_prob.dtype)
338+
339+
def _assert_dtype_part(part):
340+
if dtype_util.is_floating(part):
341+
self.assertEqual(dtype, part)
342+
343+
self.assertAllAssertsNested(_assert_dtype_part, model.dtype)
334344

335345
for name, sample_transformation in model.sample_transformations.items():
336-
transformed_points = self.evaluate(sample_transformation(test_points))
346+
transformed_points = sample_transformation(test_points)
337347

338348
def _assertions_part(expected_shape, expected_dtype, transformed_part):
339349
self.assertAllFinite(transformed_part)
340350
self.assertEqual(
341351
(batch_size,) + tuple(expected_shape),
342352
tuple(list(transformed_part.shape)))
343353
self.assertEqual(expected_dtype, transformed_part.dtype)
354+
if dtype_util.is_floating(transformed_part.dtype):
355+
self.assertEqual(dtype, transformed_part.dtype)
344356

345357
self.assertAllAssertsNested(
346358
_assertions_part,

spinoffs/inference_gym/inference_gym/targets/BUILD

+8
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ py_library(
6363
srcs = ["banana.py"],
6464
deps = [
6565
":model",
66+
# absl/testing:parameterized dep,
6667
# numpy dep,
6768
],
6869
)
@@ -73,6 +74,7 @@ py_test(
7374
srcs = ["banana_test.py"],
7475
deps = [
7576
":banana",
77+
# absl/testing:parameterized dep,
7678
# tensorflow dep,
7779
# tensorflow_probability/python/internal:test_util dep,
7880
"//inference_gym/internal:test_util",
@@ -146,6 +148,7 @@ py_test(
146148
srcs = ["eight_schools_test.py"],
147149
deps = [
148150
":eight_schools",
151+
# absl/testing:parameterized dep,
149152
# tensorflow dep,
150153
# tensorflow_probability/python/internal:test_util dep,
151154
"//inference_gym/internal:test_util",
@@ -167,6 +170,7 @@ py_test(
167170
srcs = ["ill_conditioned_gaussian_test.py"],
168171
deps = [
169172
":ill_conditioned_gaussian",
173+
# absl/testing:parameterized dep,
170174
# tensorflow dep,
171175
# tensorflow_probability/python/internal:test_util dep,
172176
"//inference_gym/internal:test_util",
@@ -220,6 +224,7 @@ py_test(
220224
shard_count = 3,
221225
deps = [
222226
":log_gaussian_cox_process",
227+
# absl/testing:parameterized dep,
223228
# numpy dep,
224229
# tensorflow dep,
225230
# tensorflow_probability/python/internal:test_util dep,
@@ -275,6 +280,7 @@ py_test(
275280
shard_count = 3,
276281
deps = [
277282
":lorenz_system",
283+
# absl/testing:parameterized dep,
278284
# numpy dep,
279285
# tensorflow dep,
280286
# tensorflow_probability/python/internal:test_util dep,
@@ -319,6 +325,7 @@ py_test(
319325
srcs = ["neals_funnel_test.py"],
320326
deps = [
321327
":neals_funnel",
328+
# absl/testing:parameterized dep,
322329
# tensorflow dep,
323330
# tensorflow_probability/python/internal:test_util dep,
324331
"//inference_gym/internal:test_util",
@@ -340,6 +347,7 @@ py_test(
340347
srcs = ["non_identifiable_quartic_test.py"],
341348
deps = [
342349
":non_identifiable_quartic",
350+
# absl/testing:parameterized dep,
343351
# tensorflow dep,
344352
# tensorflow_probability/python/internal:test_util dep,
345353
"//inference_gym/internal:test_util",

spinoffs/inference_gym/inference_gym/targets/banana.py

+11-4
Original file line numberDiff line numberDiff line change
@@ -61,16 +61,18 @@ def __init__(
6161
self,
6262
ndims=2,
6363
curvature=0.03,
64+
dtype=tf.float32,
6465
name='banana',
6566
pretty_name='Banana',
6667
):
67-
"""Construct the banana model.
68+
"""Initialize the banana model.
6869
6970
Args:
7071
ndims: Python integer. Dimensionality of the distribution. Must be at
7172
least 2.
7273
curvature: Python float. Controls the strength of the curvature of
7374
the distribution.
75+
dtype: Dtype to use for floating point quantities.
7476
name: Python `str` name prefixed to Ops created by this class.
7577
pretty_name: A Python `str`. The pretty name of this model.
7678
@@ -87,16 +89,20 @@ def bijector_fn(x):
8789
batch_shape = ps.shape(x)[:-1]
8890
shift = tf.concat(
8991
[
90-
tf.zeros(ps.concat([batch_shape, [1]], axis=0)),
92+
tf.zeros(ps.concat([batch_shape, [1]], axis=0), dtype),
9193
curvature * (tf.square(x[..., :1]) - 100),
92-
tf.zeros(ps.concat([batch_shape, [ndims - 2]], axis=0)),
94+
tf.zeros(
95+
ps.concat([batch_shape, [ndims - 2]], axis=0), dtype
96+
),
9397
],
9498
axis=-1,
9599
)
96100
return tfb.Shift(shift)
97101

98102
mg = tfd.MultivariateNormalDiag(
99-
loc=tf.zeros(ndims), scale_diag=[10.] + [1.] * (ndims - 1))
103+
loc=tf.zeros(ndims, dtype),
104+
scale_diag=[10.0] + [1.0] * (ndims - 1),
105+
)
100106
banana = tfd.TransformedDistribution(
101107
mg, bijector=tfb.MaskedAutoregressiveFlow(bijector_fn=bijector_fn))
102108

@@ -115,6 +121,7 @@ def bijector_fn(x):
115121
ground_truth_standard_deviation=np.array(
116122
[10.] + [np.sqrt(1. + 2 * curvature**2 * 10.**4)] +
117123
[1.] * (ndims - 2)),
124+
dtype=banana.dtype,
118125
)
119126
}
120127

spinoffs/inference_gym/inference_gym/targets/banana_test.py

+10-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,9 @@
1414
# ============================================================================
1515
"""Tests for inference_gym.targets.banana."""
1616

17+
from absl.testing import parameterized
18+
import tensorflow.compat.v2 as tf
19+
1720
from tensorflow_probability.python.internal import test_util as tfp_test_util
1821
from inference_gym.internal import test_util
1922
from inference_gym.targets import banana
@@ -22,18 +25,23 @@
2225
@test_util.multi_backend_test(globals(), 'targets.banana_test')
2326
class BananaTest(test_util.InferenceGymTestCase):
2427

25-
def testBasic(self):
28+
@parameterized.parameters(tf.float32, tf.float64)
29+
def testBasic(self, dtype):
2630
"""Checks that you get finite values given unconstrained samples.
2731
2832
We check `unnormalized_log_prob` as well as the values of the sample
2933
transformations.
34+
35+
Args:
36+
dtype: Dtype to use for floating point computations.
3037
"""
31-
model = banana.Banana(ndims=3)
38+
model = banana.Banana(ndims=3, dtype=dtype)
3239
self.validate_log_prob_and_transforms(
3340
model,
3441
sample_transformation_shapes=dict(identity=[3]),
3542
check_ground_truth_mean=True,
3643
check_ground_truth_standard_deviation=True,
44+
dtype=dtype,
3745
)
3846

3947
def testMC(self):

spinoffs/inference_gym/inference_gym/targets/brownian_motion.py

+43-16
Original file line numberDiff line numberDiff line change
@@ -59,24 +59,30 @@ def brownian_motion_prior_fn(num_timesteps,
5959
name='x_{}'.format(t))
6060

6161

62-
def brownian_motion_unknown_scales_prior_fn(num_timesteps, use_markov_chain):
62+
def brownian_motion_unknown_scales_prior_fn(
63+
num_timesteps, use_markov_chain, dtype
64+
):
6365
"""Generative process for the Brownian Motion model with unknown scales."""
64-
innovation_noise_scale = yield Root(tfd.LogNormal(
65-
0., 2., name='innovation_noise_scale'))
66-
_ = yield Root(tfd.LogNormal(0., 2., name='observation_noise_scale'))
66+
zero = tf.zeros([], dtype)
67+
innovation_noise_scale = yield Root(
68+
tfd.LogNormal(zero, 2.0, name='innovation_noise_scale')
69+
)
70+
_ = yield Root(tfd.LogNormal(zero, 2.0, name='observation_noise_scale'))
6771
if use_markov_chain:
6872
yield brownian_motion_as_markov_chain(
6973
num_timesteps=num_timesteps,
70-
innovation_noise_scale=innovation_noise_scale)
74+
innovation_noise_scale=innovation_noise_scale,
75+
)
7176
else:
7277
yield from brownian_motion_prior_fn(
73-
num_timesteps,
74-
innovation_noise_scale=innovation_noise_scale)
78+
num_timesteps, innovation_noise_scale=innovation_noise_scale
79+
)
7580

7681

7782
def brownian_motion_log_likelihood_fn(values,
7883
observed_locs,
7984
use_markov_chain,
85+
dtype,
8086
observation_noise_scale=None):
8187
"""Likelihood of observed data under the Brownian Motion model."""
8288
if observation_noise_scale is None:
@@ -86,7 +92,12 @@ def brownian_motion_log_likelihood_fn(values,
8692
latents = values if use_markov_chain else tf.stack(values, axis=-1)
8793

8894
observation_noise_scale = tf.convert_to_tensor(
89-
observation_noise_scale, name='observation_noise_scale')
95+
observation_noise_scale, dtype=dtype, name='observation_noise_scale')
96+
observed_locs = tf.cast(
97+
observed_locs,
98+
dtype=dtype,
99+
name='observed_locs',
100+
)
90101
is_observed = ~tf.math.is_nan(observed_locs)
91102
lps = tfd.Normal(
92103
loc=latents, scale=observation_noise_scale[..., tf.newaxis]).log_prob(
@@ -117,6 +128,7 @@ def __init__(self,
117128
innovation_noise_scale,
118129
observation_noise_scale,
119130
use_markov_chain=False,
131+
dtype=tf.float32,
120132
name='brownian_motion',
121133
pretty_name='Brownian Motion'):
122134
"""Construct the Brownian Motion model.
@@ -130,11 +142,18 @@ def __init__(self,
130142
`MarkovChain` distribution in place of separate random variables for
131143
each time step. The default of `False` is for backwards compatibility;
132144
setting this to `True` should significantly improve performance.
145+
dtype: Dtype to use for floating point quantities.
133146
name: Python `str` name prefixed to Ops created by this class.
134147
pretty_name: A Python `str`. The pretty name of this model.
135148
"""
136149
with tf.name_scope(name):
137150
num_timesteps = observed_locs.shape[0]
151+
innovation_noise_scale = tf.convert_to_tensor(
152+
innovation_noise_scale,
153+
dtype=dtype,
154+
name='innovation_noise_scale',
155+
)
156+
138157
if use_markov_chain:
139158
self._prior_dist = brownian_motion_as_markov_chain(
140159
num_timesteps=num_timesteps,
@@ -150,7 +169,8 @@ def __init__(self,
150169
brownian_motion_log_likelihood_fn,
151170
observation_noise_scale=observation_noise_scale,
152171
observed_locs=observed_locs,
153-
use_markov_chain=use_markov_chain)
172+
use_markov_chain=use_markov_chain,
173+
dtype=dtype)
154174

155175
def _ext_identity(params):
156176
return tf.stack(params, axis=-1)
@@ -164,6 +184,7 @@ def _ext_identity_markov_chain(params):
164184
fn=(_ext_identity_markov_chain
165185
if use_markov_chain else _ext_identity),
166186
pretty_name='Identity',
187+
dtype=dtype,
167188
)
168189
}
169190

@@ -193,12 +214,13 @@ class BrownianMotionMissingMiddleObservations(BrownianMotion):
193214

194215
GROUND_TRUTH_MODULE = brownian_motion_missing_middle_observations
195216

196-
def __init__(self, use_markov_chain=False):
217+
def __init__(self, use_markov_chain=False, dtype=tf.float32):
197218
dataset = data.brownian_motion_missing_middle_observations()
198219
super(BrownianMotionMissingMiddleObservations, self).__init__(
199220
name='brownian_motion_missing_middle_observations',
200221
pretty_name='Brownian Motion Missing Middle Observations',
201222
use_markov_chain=use_markov_chain,
223+
dtype=dtype,
202224
**dataset)
203225

204226

@@ -226,6 +248,7 @@ class BrownianMotionUnknownScales(bayesian_model.BayesianModel):
226248
def __init__(self,
227249
observed_locs,
228250
use_markov_chain=False,
251+
dtype=tf.float32,
229252
name='brownian_motion_unknown_scales',
230253
pretty_name='Brownian Motion with Unknown Scales'):
231254
"""Construct the Brownian Motion model with unknown scales.
@@ -238,6 +261,7 @@ def __init__(self,
238261
each time step. The default of `False` is for backwards compatibility;
239262
setting this to `True` should significantly improve performance.
240263
Default value: `False`.
264+
dtype: Dtype to use for floating point quantities.
241265
name: Python `str` name prefixed to Ops created by this class.
242266
pretty_name: A Python `str`. The pretty name of this model.
243267
"""
@@ -247,12 +271,14 @@ def __init__(self,
247271
functools.partial(
248272
brownian_motion_unknown_scales_prior_fn,
249273
use_markov_chain=use_markov_chain,
250-
num_timesteps=num_timesteps))
274+
num_timesteps=num_timesteps,
275+
dtype=dtype))
251276

252277
self._log_likelihood_fn = functools.partial(
253278
brownian_motion_log_likelihood_fn,
254279
use_markov_chain=use_markov_chain,
255-
observed_locs=observed_locs)
280+
observed_locs=observed_locs,
281+
dtype=dtype)
256282

257283
def _ext_identity(params):
258284
return {'innovation_noise_scale': params[0],
@@ -266,9 +292,9 @@ def _ext_identity(params):
266292
model.Model.SampleTransformation(
267293
fn=_ext_identity,
268294
pretty_name='Identity',
269-
dtype={'innovation_noise_scale': tf.float32,
270-
'observation_noise_scale': tf.float32,
271-
'locs': tf.float32})
295+
dtype={'innovation_noise_scale': dtype,
296+
'observation_noise_scale': dtype,
297+
'locs': dtype})
272298
}
273299

274300
event_space_bijector = type(
@@ -300,12 +326,13 @@ class BrownianMotionUnknownScalesMissingMiddleObservations(
300326
GROUND_TRUTH_MODULE = (
301327
brownian_motion_unknown_scales_missing_middle_observations)
302328

303-
def __init__(self, use_markov_chain=False):
329+
def __init__(self, use_markov_chain=False, dtype=tf.float32):
304330
dataset = data.brownian_motion_missing_middle_observations()
305331
del dataset['innovation_noise_scale']
306332
del dataset['observation_noise_scale']
307333
super(BrownianMotionUnknownScalesMissingMiddleObservations, self).__init__(
308334
name='brownian_motion_unknown_scales_missing_middle_observations',
309335
pretty_name='Brownian Motion with Unknown Scales',
310336
use_markov_chain=use_markov_chain,
337+
dtype=dtype,
311338
**dataset)

0 commit comments

Comments
 (0)