Commit f9abc8f

Add beam search decoding util (#237)
* beam search
* minor fixes
* minor changes
* minor style changes
* Fixed bug
* naming and docstring updates
* temporary change in setup
* undo setup change
* updated init to export utils
* init change
* style changes
* converted to loop based beam search
* docstring description changes
1 parent 2b69891 commit f9abc8f

File tree: 3 files changed, +309 −4 lines

Diff for: keras_nlp/utils/__init__.py

+1 line
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from keras_nlp.utils.text_generation import beam_search
 from keras_nlp.utils.text_generation import greedy_search
 from keras_nlp.utils.text_generation import random_search
 from keras_nlp.utils.text_generation import top_k_search
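This one-line change re-exports `beam_search` from `keras_nlp.utils`, next to the other search utilities. A minimal import check, assuming a checkout of this commit is installed:

```python
# Sketch: verify the new export (assumes keras_nlp is installed from this commit).
from keras_nlp.utils import beam_search

help(beam_search)  # prints the docstring shown in the next diff
```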

Diff for: keras_nlp/utils/text_generation.py

+152 −2 lines
@@ -160,6 +160,156 @@ def token_probability_fn(inputs):
     return prompt


+def beam_search(
+    token_probability_fn,
+    prompt,
+    max_length,
+    num_beams,
+    from_logits=False,
+    end_token_id=None,
+    pad_token_id=0,
+):
+    """
+    Text generation utility based on the beam search algorithm.
+
+    At each time-step, beam search keeps the beams (sequences) with the
+    `num_beams` highest accumulated probabilities, and uses each one of
+    the beams to predict candidate next tokens.
+
+    Args:
+        token_probability_fn: a callable, which takes in an input sequence
+            and outputs the probability distribution of the next token. If
+            `from_logits` is set to True, it should output the logits of
+            the next token. The input shape would be `[batch_size, length]`
+            and the output should be `[batch_size, vocab_size]`, where
+            batch_size is variable.
+        prompt: a list or a Tensor, can be 1D or 2D, the initial tokens to
+            which generated tokens are appended. Serves as the initial beam
+            for beam search.
+        max_length: int. The max length of generated text.
+        num_beams: int. The number of beams that should be kept at each
+            time-step. `num_beams` should be strictly positive.
+        from_logits: bool. Indicates whether `token_probability_fn` outputs
+            logits or probabilities.
+        end_token_id: int, defaults to None. The token marking the end of
+            the sequence; once encountered, generation finishes for that
+            sequence. If None, every sequence is generated up to
+            `max_length`. If set, all tokens after encountering
+            `end_token_id` will be replaced with `pad_token_id`.
+        pad_token_id: int, defaults to 0. The pad token used after
+            `end_token_id` is received.
+
+    Returns:
+        A 1D int Tensor, or 2D int Tensor representing the generated
+        sequences.
+
+    Examples:
+    ```python
+    BATCH_SIZE = 8
+    VOCAB_SIZE = 10
+    FEATURE_SIZE = 16
+    START_ID = 1
+    END_ID = 2
+
+    # Create a dummy model to predict the next token.
+    model = tf.keras.Sequential(
+        [
+            tf.keras.Input(shape=[None]),
+            tf.keras.layers.Embedding(
+                input_dim=VOCAB_SIZE,
+                output_dim=FEATURE_SIZE,
+            ),
+            tf.keras.layers.Dense(VOCAB_SIZE, activation="softmax"),
+        ]
+    )
+
+    # Define a function that outputs the next token's probability given the
+    # input sequence.
+    def token_probability_fn(inputs):
+        return model(inputs)[:, -1, :]
+
+    prompt = tf.fill((BATCH_SIZE, 1), START_ID)
+
+    # Print the generated sequence (token ids).
+    keras_nlp.utils.beam_search(
+        token_probability_fn,
+        prompt,
+        max_length=10,
+        num_beams=5,
+        end_token_id=END_ID,
+    )
+    ```
+    """
+    if not tf.executing_eagerly():
+        raise RuntimeError(
+            "`keras_nlp.utils.beam_search` currently requires an eager "
+            "execution context. Please call `beam_search` outside "
+            "tf.function or run `tf.config.run_functions_eagerly(True)` to run "
+            "tf.function in eager mode."
+        )
+    if num_beams <= 0:
+        raise ValueError(
+            f"`num_beams` should be strictly positive. Received: `num_beams={num_beams}`."
+        )
+
+    prompt = validate_prompt(prompt)
+
+    input_is_1d = prompt.shape.rank == 1
+    if input_is_1d:
+        prompt = prompt[tf.newaxis, :]
+    validate_token_probability_fn(token_probability_fn, prompt)
+
+    batch_size, length = prompt.shape
+    if length < max_length:
+        # Initialize the beam: every sequence starts with the prompt itself.
+        beams = tf.expand_dims(prompt, 1)
+        beams_prob = tf.zeros([batch_size, 1])
+        i = length
+        while i < max_length:
+            beam_size = beams.shape[1]
+            # Run the model once per beam to score candidate next tokens.
+            beam_preds = []
+            for j in range(beam_size):
+                preds = token_probability_fn(beams[:, j, :])
+                if from_logits:
+                    preds = tf.keras.activations.softmax(preds, axis=-1)
+                beam_preds.append(preds)
+            stacked_preds = tf.stack(beam_preds, axis=1)
+            vocab_size = stacked_preds.shape[2]
+            # Flatten (beam, vocab) so one top_k ranks all candidates jointly.
+            logits = tf.reshape(
+                stacked_preds, [batch_size, beam_size * vocab_size]
+            )
+            probs = tf.math.log(logits) + tf.repeat(
+                beams_prob, repeats=vocab_size, axis=1
+            )
+            num_beams = min(beam_size * vocab_size, num_beams)
+            candidate_prob, candidate_indexes = tf.math.top_k(
+                probs, k=num_beams, sorted=False
+            )
+            # Recover the source beam and appended token for each candidate.
+            candidate_beam_indexes = candidate_indexes // vocab_size
+            next_token = candidate_indexes % vocab_size
+
+            beams = tf.gather(
+                beams, candidate_beam_indexes, axis=1, batch_dims=1
+            )
+            beams = tf.concat([beams, next_token[..., tf.newaxis]], axis=-1)
+            beams_prob = candidate_prob
+            i += 1
+        # Get the beam with the maximum probability.
+        max_indexes = tf.math.argmax(beams_prob, axis=-1)
+        max_beams = tf.gather(
+            beams, max_indexes[:, tf.newaxis], axis=1, batch_dims=1
+        )
+        prompt = tf.squeeze(max_beams)
+
+    if end_token_id is not None:
+        prompt = mask_tokens_after_end_token(
+            prompt, max_length, end_token_id, pad_token_id
+        )
+    if input_is_1d:
+        return tf.squeeze(prompt)
+    return prompt
+
+
 def random_search(
     token_probability_fn,
     prompt,
@@ -361,7 +511,7 @@ def token_probability_fn(inputs):
361511
"tf.function in eager mode."
362512
)
363513
if k <= 0:
364-
raise ValueError(f"`k` should strictly positive. Received: `k={k}`.")
514+
raise ValueError(f"`k` should be strictly positive. Received: `k={k}`.")
365515

366516
prompt = validate_prompt(prompt)
367517
input_is_1d = prompt.shape.rank == 1
@@ -378,7 +528,7 @@ def token_probability_fn(inputs):
378528
# If k is greater than the vocabulary size, use the entire vocabulary.
379529
k = min(k, pred.shape[1])
380530
# Filter out top-k tokens.
381-
top_k_pred, top_k_indices = tf.math.top_k(pred, k=k)
531+
top_k_pred, top_k_indices = tf.math.top_k(pred, k=k, sorted=False)
382532
# Sample the next token from the probability distribution.
383533
next_token = tf.random.categorical(
384534
tf.math.log(top_k_pred), 1, seed=seed
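The heart of the new `beam_search` loop is the candidate-selection step: per-beam next-token scores are flattened into a single `[batch_size, beam_size * vocab_size]` tensor, one `tf.math.top_k` call ranks every (beam, token) pair jointly, and integer division and modulo by `vocab_size` recover which beam each winner extends and which token it appends. A standalone sketch of just that step; only `beams_prob`, `vocab_size`, and the top-k arithmetic mirror the diff, the numbers are made up:

```python
import tensorflow as tf

batch_size, beam_size, vocab_size, num_beams = 1, 2, 3, 2

# Accumulated log-probabilities of the two live beams.
beams_prob = tf.constant([[-0.5, -1.0]])  # shape [batch_size, beam_size]

# Next-token probabilities predicted for each beam (illustrative values).
stacked_preds = tf.constant(
    [[[0.7, 0.2, 0.1], [0.1, 0.1, 0.8]]]
)  # shape [batch_size, beam_size, vocab_size]

# Flatten (beam, vocab) so a single top_k ranks all beam/token pairs at once.
flat = tf.reshape(stacked_preds, [batch_size, beam_size * vocab_size])
probs = tf.math.log(flat) + tf.repeat(beams_prob, repeats=vocab_size, axis=1)

candidate_prob, candidate_indexes = tf.math.top_k(probs, k=num_beams, sorted=False)

# Unflatten: the source beam and the appended token for each candidate.
candidate_beam_indexes = candidate_indexes // vocab_size  # here: {0, 1}
next_token = candidate_indexes % vocab_size               # here: {0, 2}
```

Note that both `beam_search` and the tweaked `top_k_search` pass `sorted=False` to `tf.math.top_k`: the winners are either re-gathered by index or sampled from with `tf.random.categorical`, so the relative order of the k results never matters and the final sort can be skipped.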

Diff for: keras_nlp/utils/text_generation_test.py

+156 −2 lines
@@ -13,10 +13,12 @@
 # limitations under the License.
 """Tests for Text Generation Utils."""

+import random

 import numpy as np
 import tensorflow as tf

+from keras_nlp.utils.text_generation import beam_search
 from keras_nlp.utils.text_generation import greedy_search
 from keras_nlp.utils.text_generation import random_search
 from keras_nlp.utils.text_generation import top_k_search
@@ -111,6 +113,160 @@ def token_probability_fn(inputs):
         self.assertAllEqual(outputs, expected_outputs)


+class BeamSearchTextGenerationTest(tf.test.TestCase):
+    def setUp(self):
+        super().setUp()
+        vocab_size = 10
+        feature_size = 16
+
+        # Create a dummy model to predict the next token.
+        model = tf.keras.Sequential(
+            [
+                tf.keras.Input(shape=[None]),
+                tf.keras.layers.Embedding(
+                    input_dim=vocab_size,
+                    output_dim=feature_size,
+                ),
+                tf.keras.layers.Dense(vocab_size),
+                tf.keras.layers.Softmax(),
+            ]
+        )
+
+        def token_probability_fn(inputs):
+            return model(inputs)[:, -1, :]
+
+        self.token_probability_fn = token_probability_fn
+
+    def test_generate_with_empty_prompt(self):
+        inputs = tf.constant([])
+        with self.assertRaises(ValueError):
+            beam_search(
+                self.token_probability_fn, inputs, max_length=5, num_beams=5
+            )
+        inputs = tf.constant([[]])
+        with self.assertRaises(ValueError):
+            beam_search(
+                self.token_probability_fn, inputs, max_length=5, num_beams=5
+            )
+
+    def test_generate_with_1d_prompt(self):
+        inputs = tf.constant([1])
+        outputs = beam_search(
+            self.token_probability_fn,
+            inputs,
+            max_length=5,
+            num_beams=5,
+        )
+        self.assertEquals(outputs.shape, [5])
+
+    def test_generate_with_2d_prompt(self):
+        inputs = tf.constant([[1], [1]])
+        outputs = beam_search(
+            self.token_probability_fn,
+            inputs,
+            max_length=5,
+            num_beams=5,
+        )
+        self.assertEquals(outputs.shape, [2, 5])
+
+    def test_generate_with_list_prompt(self):
+        inputs = [[1], [1]]
+        outputs = beam_search(
+            self.token_probability_fn,
+            inputs,
+            max_length=5,
+            num_beams=5,
+        )
+        self.assertEquals(outputs.shape, [2, 5])
+
+    def test_generate_with_ragged_prompt(self):
+        inputs = tf.ragged.constant([[1], [2, 3]])
+        with self.assertRaises(ValueError):
+            beam_search(
+                self.token_probability_fn,
+                inputs,
+                max_length=5,
+                num_beams=5,
+            )
+
+    def test_one_beam_generation(self):
+        # With a single beam, beam search must reduce to greedy search.
+        for i in range(50):
+            inputs = tf.constant([random.randint(0, 9)])
+            beam_output = beam_search(
+                self.token_probability_fn,
+                inputs,
+                max_length=5,
+                num_beams=1,
+            )
+            greedy_output = greedy_search(
+                self.token_probability_fn,
+                inputs,
+                max_length=5,
+            )
+            self.assertAllEqual(beam_output, greedy_output)
+
+    def test_multiple_beam_generation(self):
+        def token_probability_fn(inputs):
+            if inputs.shape[1] == 1:
+                prob = tf.constant([[0.1, 0.2, 0.3, 0.4]])
+            elif inputs[0][1] == 2:
+                prob = tf.constant([[0.9, 0.08, 0.01, 0.01]])
+            elif inputs[0][1] == 3:
+                prob = tf.constant([[0.25, 0.25, 0.25, 0.25]])
+            return prob
+
+        inputs = tf.constant([[1]])
+        beam_output = beam_search(
+            token_probability_fn,
+            inputs,
+            max_length=3,
+            num_beams=2,
+        )
+        self.assertAllEqual(
+            beam_output, tf.constant([1, 2, 0], dtype=beam_output.dtype)
+        )
+
+    def test_assert_generation_is_correct(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        batch_size = 10
+        inputs = 3 * tf.ones([batch_size, 1], dtype=tf.int32)
+        max_length = 3
+        for i in range(1, 10):
+            outputs = beam_search(
+                token_probability_fn,
+                inputs,
+                max_length=max_length,
+                num_beams=i,
+            )
+            self.assertAllEqual(
+                outputs, 3 * tf.ones(shape=[batch_size, max_length])
+            )
+
+    def test_end_token_id(self):
+        def token_probability_fn(inputs):
+            batch_size = inputs.shape[0]
+            prob = tf.constant([[0.01, 0.01, 0.08, 0.9]])
+            return tf.repeat(prob, batch_size, axis=0)
+
+        max_length = 5
+        inputs = tf.constant([[0, 1], [1, 2]])
+        outputs = beam_search(
+            token_probability_fn,
+            inputs,
+            max_length=max_length,
+            num_beams=2,
+            end_token_id=2,
+            pad_token_id=0,
+        )
+        expected_outputs = tf.tile([[3], [0]], [1, max_length - 2])
+        expected_outputs = tf.concat([inputs, expected_outputs], axis=1)
+        self.assertAllEqual(outputs, expected_outputs)
+
+
 class RandomSearchTextGenerationTest(tf.test.TestCase):
     def setUp(self):
         super().setUp()
@@ -334,8 +490,6 @@ def token_probability_fn(inputs):
         )
         # Top-k sampling result with seed 42.
         seeded_result = 3 * np.ones(shape=[batch_size, max_length])
-        seeded_result[3][1] = 2
-        seeded_result[7][1] = 2
         self.assertAllEqual(outputs, seeded_result)

     def test_assert_probability_distribution_generation_is_correct(self):
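The expected output `[1, 2, 0]` in `test_multiple_beam_generation` is worth tracing by hand, since it is the one case above where beam search must beat a greedy choice. A quick arithmetic check using the probabilities hard-coded in the test (a sketch, not part of the commit):

```python
import math

# Step 1: from prompt [1], the next-token probs are [0.1, 0.2, 0.3, 0.4],
# so num_beams=2 keeps the beams [1, 3] (p=0.4) and [1, 2] (p=0.3).
# Step 2: a beam ending in 2 sees [0.9, 0.08, 0.01, 0.01];
#         a beam ending in 3 sees [0.25, 0.25, 0.25, 0.25].
best_via_3 = math.log(0.4) + math.log(0.25)  # [1, 3, 0]: joint p = 0.100
best_via_2 = math.log(0.3) + math.log(0.9)   # [1, 2, 0]: joint p = 0.270

# The locally weaker first token wins overall, which greedy search would miss.
assert best_via_2 > best_via_3
```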
