Add a dedicated padding token to beam search to avoid padding with the start sentence token #10386

Open · wants to merge 1 commit into base: master
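
For context (not part of the PR diff below): a minimal sketch of what this change does to the initial alive_seq buffer in padded decoding. The shapes and the pad_id value here are illustrative, not taken from the PR.

import tensorflow as tf

# Hypothetical setup: batch_size=1, beam_size=2, max_decode_length=3, pad_id=42.
initial_ids = tf.zeros([1], dtype=tf.int32)                  # [batch_size]
alive_seq = tf.tile(tf.expand_dims(initial_ids, 1), [1, 2])  # [batch, beam]
alive_seq = tf.expand_dims(alive_seq, axis=2)                # [batch, beam, 1]

# Old behavior: tf.tile fills every decode position with the start id (0),
# so trailing padding is indistinguishable from real start-of-sentence tokens.
old = tf.tile(alive_seq, [1, 1, 3 + 1])
print(old.numpy())  # [[[0 0 0 0] [0 0 0 0]]]

# New behavior: only position 0 holds the start id; the rest is pad_id.
new = tf.pad(alive_seq, [[0, 0], [0, 0], [0, 3]], constant_values=42)
print(new.numpy())  # [[[ 0 42 42 42] [ 0 42 42 42]]]
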
9 changes: 7 additions & 2 deletions official/nlp/modeling/ops/beam_search.py
@@ -107,6 +107,7 @@ def __init__(self,
                 max_decode_length,
                 eos_id,
                 padded_decode,
+                pad_id,
                 dtype=tf.float32):
     """Initialize sequence beam search.

@@ -128,6 +129,7 @@
       eos_id: An integer. ID of end of sentence token.
       padded_decode: A bool, indicating if max_sequence_length padding is used
         for beam search.
+      pad_id: An integer, ID to be used to pad predictions.
       dtype: A tensorflow data type used for score computation. The default is
         tf.float32.
     """
@@ -138,6 +140,7 @@
     self.max_decode_length = max_decode_length
     self.eos_id = eos_id
     self.padded_decode = padded_decode
+    self.pad_id = pad_id
     self.dtype = tf.as_dtype(dtype)

   def search(self, initial_ids, initial_cache):
@@ -409,7 +412,7 @@ def _create_initial_state(self, initial_ids, initial_cache, batch_size):
     alive_seq = expand_to_beam_size(initial_ids, self.beam_size)
     alive_seq = tf.expand_dims(alive_seq, axis=2)
     if self.padded_decode:
-      alive_seq = tf.tile(alive_seq, [1, 1, self.max_decode_length + 1])
+        alive_seq = tf.pad(alive_seq, [[0, 0], [0, 0], [0, self.max_decode_length]], constant_values=self.pad_id)
Collaborator: We need to remove the extra spaces in front of this line to conform to our Python styling.
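
For reference, a possible reflow of the added line that would conform to a 2-space-indent, 80-column Python style (a sketch; the exact wrapping is up to the author):

    if self.padded_decode:
      alive_seq = tf.pad(
          alive_seq, [[0, 0], [0, 0], [0, self.max_decode_length]],
          constant_values=self.pad_id)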

Collaborator: I think we need to resolve the merge conflict as well; I will create a new PR with all of the above changes.


     # Create tensor for storing initial log probabilities.
     # Assume initial_ids are prob 1.0
@@ -587,6 +590,7 @@ def sequence_beam_search(symbols_to_logits_fn,
                          max_decode_length,
                          eos_id,
                          padded_decode=False,
+                         pad_id=0,
                          dtype="float32"):
   """Search for sequence of subtoken ids with the largest probability.

@@ -610,6 +614,7 @@
       finished.
     padded_decode: A bool, indicating if max_sequence_length padding is used for
       beam search.
+    pad_id: An integer, ID to be used to pad predictions.
     dtype: A tensorflow data type used for score computation. The default is
       tf.float32.

@@ -618,7 +623,7 @@
     sequence scores [batch_size, beam_size]
   """
   sbs = SequenceBeamSearch(symbols_to_logits_fn, vocab_size, beam_size, alpha,
-                           max_decode_length, eos_id, padded_decode, dtype)
+                           max_decode_length, eos_id, padded_decode, pad_id, dtype)
   return sbs.search(initial_ids, initial_cache)


30 changes: 30 additions & 0 deletions official/nlp/modeling/ops/beam_search_test.py
@@ -94,6 +94,36 @@ def symbols_to_logits_fn(_, i, cache):
         dtype=tf.float32)
     self.assertAllEqual([[[0, 1, 0, 1], [0, 1, 1, 2]]], predictions)

+  def test_sequence_beam_search_with_truncated_prediction(self):
+    # batch_size * beam_size, max_decode_length, vocab_size
+    probabilities = tf.constant([[[0.2, 0.7, 0.1], [0.2, 0.3, 0.5], [0.1, 0.8, 0.1]]])
+    # batch_size, max_decode_length, num_heads, embed_size per head
+    x = tf.zeros([1, 3, 2, 32], dtype=tf.float32)
+    cache = {'layer_%d' % layer: {'k': x, 'v': x} for layer in range(2)}
+
+    def _get_test_symbols_to_logits_fn():
+      """Test function that returns logits for next token."""
+
+      def symbols_to_logits_fn(_, i, cache):
+        logits = tf.cast(probabilities[:, i, :], tf.float32)
+        return logits, cache
+
+      return symbols_to_logits_fn
+
+    predictions, _ = beam_search.sequence_beam_search(
+        symbols_to_logits_fn=_get_test_symbols_to_logits_fn(),
+        initial_ids=tf.zeros([1], dtype=tf.int32),
+        initial_cache=cache,
+        vocab_size=3,
+        beam_size=1,
+        alpha=0.6,
+        max_decode_length=3,
+        eos_id=2,
+        padded_decode=True,
+        pad_id=42,
+        dtype=tf.float32)
+    self.assertAllEqual([[[0, 1, 2, 42]]], predictions)
+

 if __name__ == '__main__':
   tf.test.main()