1
+ # Copyright 2022 The Flax Authors.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from absl .testing import absltest
16
+ import jax
17
+ import jax .numpy as jnp
18
+ import numpy as np
19
+
20
+ from temperature_sampler import temperature_sample
21
+
22
+
23
+ jax .config .update ('jax_disable_most_optimizations' , True )
24
+
25
+
26
+ class TestTemperatureSampler (absltest .TestCase ):
27
+ def test_temperature_sampler (self ):
28
+
29
+ tokens = jnp .array ([[5 , 0 , 0 , 0 ]], dtype = jnp .int32 )
30
+ cache = None
31
+ key = jax .random .PRNGKey (0 )
32
+
33
+ def tokens_to_logits (tokens , cache ):
34
+ jax .debug .print ("tokens: {}" , tokens )
35
+ logits = jax .nn .one_hot (tokens [..., - 1 :] + 1 , 10 )
36
+ logits = jnp .where (logits < 0.5 , float ('-inf' ), logits )
37
+ logits = logits .squeeze (axis = 1 )
38
+ return logits , cache
39
+
40
+ new_tokens = temperature_sample (tokens , cache , tokens_to_logits , key , topk = 5 )
41
+
42
+ np .testing .assert_array_equal (new_tokens , [[5 , 6 , 7 , 8 ]])
43
+
44
+ if __name__ == '__main__' :
45
+ absltest .main ()
0 commit comments