Skip to content

Commit f9dab0a

Browse files
author
Flax Authors
committed
Merge pull request #2842 from cgarciae:fix-1322
PiperOrigin-RevId: 512840808
2 parents 60ab1cc + 2e1232e commit f9dab0a

File tree

2 files changed

+47
-2
lines changed

2 files changed

+47
-2
lines changed

examples/lm1b/temperature_sampler.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,7 @@ def temperature_sample(prompt_inputs,
5757
# initial loop PRNGKey
5858
rng0 = prng_key
5959
# loop position counter.
60-
i0 = jnp.array(0)
60+
i0 = jnp.array(-1)
6161
# per batch-item holding current token in loop.
6262
token0 = jnp.zeros((batch_size, 1), dtype=jnp.int32)
6363
# per batch-item state bit indicating if sentence has finished.
@@ -72,7 +72,7 @@ def sampling_loop_cond_fn(state):
7272
"""Sampling loop termination condition."""
7373
(i, _, _, _, ended, _) = state
7474
# Have we reached max decoding length?
75-
not_at_end = (i < max_decode_len)
75+
not_at_end = (i < max_decode_len - 1)
7676
# Have all sampled sequences reached an end marker?
7777
all_sequences_ended = jnp.all(ended)
7878
return not_at_end & (~all_sequences_ended)
+45
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
# Copyright 2022 The Flax Authors.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from absl.testing import absltest
16+
import jax
17+
import jax.numpy as jnp
18+
import numpy as np
19+
20+
from temperature_sampler import temperature_sample
21+
22+
23+
jax.config.update('jax_disable_most_optimizations', True)
24+
25+
26+
class TestTemperatureSampler(absltest.TestCase):
27+
def test_temperature_sampler(self):
28+
29+
tokens = jnp.array([[5, 0, 0, 0]], dtype=jnp.int32)
30+
cache = None
31+
key = jax.random.PRNGKey(0)
32+
33+
def tokens_to_logits(tokens, cache):
34+
jax.debug.print("tokens: {}", tokens)
35+
logits = jax.nn.one_hot(tokens[..., -1:] + 1, 10)
36+
logits = jnp.where(logits < 0.5, float('-inf'), logits)
37+
logits = logits.squeeze(axis=1)
38+
return logits, cache
39+
40+
new_tokens = temperature_sample(tokens, cache, tokens_to_logits, key, topk=5)
41+
42+
np.testing.assert_array_equal(new_tokens, [[5, 6, 7, 8]])
43+
44+
if __name__ == '__main__':
45+
absltest.main()

0 commit comments

Comments
 (0)