Skip to content

Commit 4c742aa

Browse files
ratgrtensorflower-gardener
authored andcommitted
This change updates the random number generation utility to use tf.uint64 for bitwise operations and constants. The seed generation function (next_seed_fn) has been rewritten to implement a proper 7-bit LFSR (PRBS7) using bitwise operations, replacing the previous power-based approximation.
PiperOrigin-RevId: 915708863
1 parent 1d10a34 commit 4c742aa

2 files changed

Lines changed: 59 additions & 11 deletions

File tree

tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils.py

Lines changed: 50 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -153,33 +153,71 @@ def _cmwc_random_sequence(num_elements, seed):
153153

154154
# Create constants needed for the algorithm. The constants and notation
155155
# follows from the above reference.
156-
a = tf.tile(tf.constant([3636507990], tf.int64), [parallelism])
157-
b = tf.tile(tf.constant([2**32], tf.int64), [parallelism])
158-
logb_scalar = tf.constant(32, tf.int64)
156+
a = tf.tile(tf.constant([3636507990], tf.uint64), [parallelism])
157+
b = tf.tile(tf.constant([2**32], tf.uint64), [parallelism])
158+
logb_scalar = tf.constant(32, tf.uint64)
159159
logb = tf.tile([logb_scalar], [parallelism])
160-
f = tf.tile(tf.constant([0], dtype=tf.int64), [parallelism])
161-
bits = tf.constant(0, dtype=tf.int64, name='bits')
160+
f = tf.tile(tf.constant([0], dtype=tf.uint64), [parallelism])
161+
bits = tf.constant(0, dtype=tf.uint64, name='bits')
162162

163163
# TensorArray used in tf.while_loop for efficiency.
164164
values = tf.TensorArray(
165165
dtype=tf.float64, size=num_iters, element_shape=[parallelism])
166166
# Iteration counter.
167167
num = tf.constant(0, dtype=tf.int32, name='num')
168168
# TensorFlow constant to be used at multiple places.
169-
val_53 = tf.constant(53, tf.int64, name='val_53')
169+
val_53 = tf.constant(53, tf.uint64, name='val_53')
170170

171171
# Construct initial sequence of seeds.
172172
# From a single input seed, we construct multiple starting seeds for the
173173
# sequences to be computed in parallel.
174174
def next_seed_fn(i, val, q):
175-
val = val**7 + val**6 + 1 # PRBS7.
175+
"""Generates the next seed using a 7-bit LFSR.
176+
177+
This function implements a proper 7-bit Fibonacci LFSR with the polynomial
178+
x^7 + x^6 + 1. It takes the lower 7 bits of `val` as the current state,
179+
computes the next state, and writes it to the TensorArray `q`.
180+
181+
Args:
182+
i: The current index in the while loop.
183+
val: The current seed value (tf.uint64). The lower 7 bits are used as
184+
the LFSR state.
185+
q: The tf.TensorArray to write the generated seed into.
186+
187+
Returns:
188+
A tuple of (i + 1, new_val, q), where `new_val` is the next state of the
189+
LFSR.
190+
"""
191+
state = tf.bitwise.bitwise_and(val, tf.constant(0x7F, tf.uint64))
192+
# Avoid zero state, which is a trapping state for this LFSR polynomial.
193+
state = tf.bitwise.bitwise_or(
194+
state,
195+
tf.cast(tf.equal(state, tf.constant(0, tf.uint64)), tf.uint64)
196+
)
197+
# Feedback bit = bit 7 (index 6) ^ bit 6 (index 5)
198+
feedback = tf.bitwise.bitwise_and(
199+
tf.bitwise.bitwise_xor(
200+
tf.bitwise.right_shift(state, tf.constant(6, tf.uint64)),
201+
tf.bitwise.right_shift(state, tf.constant(5, tf.uint64))
202+
),
203+
tf.constant(1, tf.uint64)
204+
)
205+
# Shift left and insert feedback
206+
val = tf.bitwise.bitwise_and(
207+
tf.bitwise.bitwise_or(
208+
tf.bitwise.left_shift(state, tf.constant(1, tf.uint64)),
209+
feedback
210+
),
211+
tf.constant(0x7F, tf.uint64)
212+
)
176213
q = q.write(i, val)
177214
return i + 1, val, q
178215

179-
q = tf.TensorArray(dtype=tf.int64, size=parallelism, element_shape=())
216+
q = tf.TensorArray(dtype=tf.uint64, size=parallelism, element_shape=())
217+
seed_u64 = tf.cast(seed, tf.uint64)
180218
_, _, q = tf.while_loop(lambda i, _, __: i < parallelism,
181219
next_seed_fn,
182-
[tf.constant(0), seed, q])
220+
[tf.constant(0), seed_u64, q])
183221
c = q = q.stack()
184222

185223
# The random sequence generation code.
@@ -193,9 +231,10 @@ def cmwc_step(f, bits, q, c, num, values):
193231
f.set_shape((1,)) # Correct for failed shape inference.
194232
bits += logb_scalar
195233
def add_val(bits, f, values, num):
234+
mask_53 = tf.constant(2**53 - 1, tf.uint64)
196235
new_val = tf.cast(
197-
tf.bitwise.bitwise_and(f, (2**val_53 - 1)),
198-
dtype=tf.float64) * (1 / 2**val_53)
236+
tf.bitwise.bitwise_and(f, mask_53),
237+
dtype=tf.float64) * (1.0 / 2.0**53)
199238
values = values.write(num, new_val)
200239
f += tf.bitwise.right_shift(f, val_53)
201240
bits -= val_53

tensorflow_model_optimization/python/core/internal/tensor_encoding/utils/tf_utils_test.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -188,6 +188,15 @@ def test_tf_int32_seed_raises(self):
188188
with self.assertRaisesRegex(TypeError, 'tf.int64 Tensor'):
189189
tf_utils._cmwc_random_sequence(10, tf.constant(123, tf.int32))
190190

191+
def test_reproduction_b511305971(self):
192+
"""Verifies that the PRNG does not produce negative states or bounds violations."""
193+
# Reproduction steps from b/511305971
194+
sequence = tf_utils._cmwc_random_sequence(
195+
1000, tf.constant(12345, dtype=tf.int64))
196+
sequence = self.evaluate(sequence)
197+
self.assertAllGreaterEqual(sequence, 0.0)
198+
self.assertAllLessEqual(sequence, 1.0)
199+
191200

192201
class RandomSignsCMWCTests(tf.test.TestCase, parameterized.TestCase):
193202
"""Tests for `random_signs_cmwc` method."""

0 commit comments

Comments
 (0)