@@ -160,6 +160,156 @@ def token_probability_fn(inputs):
     return prompt
 
 
+def beam_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    num_beams,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on the beam search algorithm.
+
+    At each time-step, beam search keeps the beams (sequences) with the top
+    `num_beams` highest accumulated probabilities, and uses each one of the
+    beams to predict candidate next tokens.
+
+    Args:
+        token_probability_fn: a callable, which takes in input_sequence
+            and outputs the probability distribution of the next token. If
+            `from_logits` is set to True, it should output the logits of the
+            next token. The input shape would be `[batch_size, length]` and
+            the output should be `[batch_size, vocab_size]`, where batch_size
+            is variable.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
+            append generated tokens to. The initial beam for beam search.
+        max_length: int. The max length of generated text.
+        num_beams: int. The number of beams that should be kept at each
+            time-step. `num_beams` should be strictly positive.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of the
+            sequence; once encountered, generation is finished for that exact
+            sequence. If None, every sequence is generated up to `max_length`.
+            If set, all tokens after encountering `end_token_id` will be
+            replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The pad token after `end_token_id`
+            is received.
+
+    Returns:
+        A 1D int Tensor or a 2D int Tensor representing the generated
+        sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.beam_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        num_beams=5,
+        end_token_id=END_ID,
+    )
+    ```
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.beam_search` currently requires an eager "
+            "execution context. Please call `beam_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if num_beams <= 0:
+        raise ValueError(
+            f"`num_beams` should be strictly positive. Received: `num_beams={num_beams}`."
+        )
+
+    prompt = validate_prompt(prompt)
+
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    batch_size, length = prompt.shape
+    if length < max_length:
+        # Initialize beam.
+        beams = tf.expand_dims(prompt, 1)
+        beams_prob = tf.zeros([batch_size, 1])
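+        # `beams` has shape `(batch_size, beam_count, length)`; `beams_prob`
+        # holds each beam's accumulated log-probability, starting at
+        # log(1) = 0 for the prompt-only beam.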
+        i = length
+        while i < max_length:
+            beam_size = beams.shape[1]
+            beam_preds = []
+            for j in range(beam_size):
+                preds = token_probability_fn(beams[:, j, :])
+                if from_logits:
+                    preds = tf.keras.activations.softmax(preds, axis=-1)
+                beam_preds.append(preds)
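+            # Stack the per-beam predictions into shape
+            # `(batch_size, beam_size, vocab_size)`.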
+            stacked_preds = tf.stack(beam_preds, axis=1)
+            vocab_size = stacked_preds.shape[2]
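+            # Flatten the (beam, vocab) axes so that a single top-k call can
+            # rank all `beam_size * vocab_size` candidate continuations of
+            # each batch element jointly.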
+            logits = tf.reshape(
+                stacked_preds, [batch_size, beam_size * vocab_size]
+            )
+            probs = tf.math.log(logits) + tf.repeat(
+                beams_prob, repeats=vocab_size, axis=1
+            )
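+            # `probs` now holds the accumulated log-probability of every
+            # candidate continuation; working in log space turns products of
+            # probabilities into sums and avoids floating-point underflow on
+            # long sequences.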
+            num_beams = min(beam_size * vocab_size, num_beams)
+            candidate_prob, candidate_indexes = tf.math.top_k(
+                probs, k=num_beams, sorted=False
+            )
+            candidate_beam_indexes = candidate_indexes // vocab_size
+            next_token = candidate_indexes % vocab_size
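+            # Each flattened index decodes back into (source beam, token id):
+            # integer division recovers the beam, the remainder the token.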
+
+            beams = tf.gather(
+                beams, candidate_beam_indexes, axis=1, batch_dims=1
+            )
+            beams = tf.concat([beams, next_token[..., tf.newaxis]], axis=-1)
+            beams_prob = candidate_prob
+            i += 1
+        # Get the beam with the maximum probability.
+        max_indexes = tf.math.argmax(beams_prob, axis=-1)
+        max_beams = tf.gather(
+            beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
+        )
+        prompt = tf.squeeze(max_beams)
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
+
+
 def random_search(
     token_probability_fn,
     prompt,
@@ -361,7 +511,7 @@ def token_probability_fn(inputs):
             "tf.function in eager mode."
         )
     if k <= 0:
-        raise ValueError(f"`k` should strictly positive. Received: `k={k}`.")
+        raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")
 
     prompt = validate_prompt(prompt)
     input_is_1d = prompt.shape.rank == 1
@@ -378,7 +528,7 @@ def token_probability_fn(inputs):
         # If k is greater than the vocabulary size, use the entire vocabulary.
         k = min(k, pred.shape[1])
         # Filter out top-k tokens.
-        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k)
+        top_k_pred, top_k_indices = tf.math.top_k(pred, k=k, sorted=False)
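+        # `sorted=False` skips ordering the k values; order is irrelevant
+        # here because the next token is sampled from the set anyway.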
         # Sample the next token from the probability distribution.
         next_token = tf.random.categorical(
             tf.math.log(top_k_pred), 1, seed=seed