@@ -153,33 +153,71 @@ def _cmwc_random_sequence(num_elements, seed):
153153
154154 # Create constants needed for the algorithm. The constants and notation
155155 # follow the above reference.
156- a = tf .tile (tf .constant ([3636507990 ], tf .int64 ), [parallelism ])
157- b = tf .tile (tf .constant ([2 ** 32 ], tf .int64 ), [parallelism ])
158- logb_scalar = tf .constant (32 , tf .int64 )
156+ a = tf .tile (tf .constant ([3636507990 ], tf .uint64 ), [parallelism ])
157+ b = tf .tile (tf .constant ([2 ** 32 ], tf .uint64 ), [parallelism ])
158+ logb_scalar = tf .constant (32 , tf .uint64 )
159159 logb = tf .tile ([logb_scalar ], [parallelism ])
160- f = tf .tile (tf .constant ([0 ], dtype = tf .int64 ), [parallelism ])
161- bits = tf .constant (0 , dtype = tf .int64 , name = 'bits' )
160+ f = tf .tile (tf .constant ([0 ], dtype = tf .uint64 ), [parallelism ])
161+ bits = tf .constant (0 , dtype = tf .uint64 , name = 'bits' )
162162
163163 # TensorArray used in tf.while_loop for efficiency.
164164 values = tf .TensorArray (
165165 dtype = tf .float64 , size = num_iters , element_shape = [parallelism ])
166166 # Iteration counter.
167167 num = tf .constant (0 , dtype = tf .int32 , name = 'num' )
168168 # TensorFlow constant to be used at multiple places.
169- val_53 = tf .constant (53 , tf .int64 , name = 'val_53' )
169+ val_53 = tf .constant (53 , tf .uint64 , name = 'val_53' )
170170
171171 # Construct initial sequence of seeds.
172172 # From a single input seed, we construct multiple starting seeds for the
173173 # sequences to be computed in parallel.
174174 def next_seed_fn (i , val , q ):
175- val = val ** 7 + val ** 6 + 1 # PRBS7.
175+ """Generates the next seed using a 7-bit LFSR.
176+
177+ This function implements a proper 7-bit Fibonacci LFSR with the polynomial
178+ x^7 + x^6 + 1. It takes the lower 7 bits of `val` as the current state,
179+ computes the next state, and writes it to the TensorArray `q`.
180+
181+ Args:
182+ i: The current index in the while loop.
183+ val: The current seed value (tf.uint64). The lower 7 bits are used as
184+ the LFSR state.
185+ q: The tf.TensorArray to write the generated seed into.
186+
187+ Returns:
188+ A tuple of (i + 1, new_val, q), where `new_val` is the next state of the
189+ LFSR.
190+ """
191+ state = tf .bitwise .bitwise_and (val , tf .constant (0x7F , tf .uint64 ))
192+ # Avoid zero state, which is a trapping state for this LFSR polynomial.
193+ state = tf .bitwise .bitwise_or (
194+ state ,
195+ tf .cast (tf .equal (state , tf .constant (0 , tf .uint64 )), tf .uint64 )
196+ )
197+ # Feedback bit = bit 7 (index 6) XOR bit 6 (index 5), the taps of x^7 + x^6 + 1.
198+ feedback = tf .bitwise .bitwise_and (
199+ tf .bitwise .bitwise_xor (
200+ tf .bitwise .right_shift (state , tf .constant (6 , tf .uint64 )),
201+ tf .bitwise .right_shift (state , tf .constant (5 , tf .uint64 ))
202+ ),
203+ tf .constant (1 , tf .uint64 )
204+ )
205+ # Shift left by one, insert the feedback bit at index 0, and mask back to 7 bits.
206+ val = tf .bitwise .bitwise_and (
207+ tf .bitwise .bitwise_or (
208+ tf .bitwise .left_shift (state , tf .constant (1 , tf .uint64 )),
209+ feedback
210+ ),
211+ tf .constant (0x7F , tf .uint64 )
212+ )
176213 q = q .write (i , val )
177214 return i + 1 , val , q
178215
179- q = tf .TensorArray (dtype = tf .int64 , size = parallelism , element_shape = ())
216+ q = tf .TensorArray (dtype = tf .uint64 , size = parallelism , element_shape = ())
217+ seed_u64 = tf .cast (seed , tf .uint64 )
180218 _ , _ , q = tf .while_loop (lambda i , _ , __ : i < parallelism ,
181219 next_seed_fn ,
182- [tf .constant (0 ), seed , q ])
220+ [tf .constant (0 ), seed_u64 , q ])
183221 c = q = q .stack ()
184222
185223 # The random sequence generation code.
@@ -193,9 +231,10 @@ def cmwc_step(f, bits, q, c, num, values):
193231 f .set_shape ((1 ,)) # Correct for failed shape inference.
194232 bits += logb_scalar
195233 def add_val (bits , f , values , num ):
234+ mask_53 = tf .constant (2 ** 53 - 1 , tf .uint64 )
196235 new_val = tf .cast (
197- tf .bitwise .bitwise_and (f , ( 2 ** val_53 - 1 ) ),
198- dtype = tf .float64 ) * (1 / 2 ** val_53 )
236+ tf .bitwise .bitwise_and (f , mask_53 ),
237+ dtype = tf .float64 ) * (1.0 / 2.0 ** 53 )
199238 values = values .write (num , new_val )
200239 f += tf .bitwise .right_shift (f , val_53 )
201240 bits -= val_53
0 commit comments