-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathrandom_sampler.py
More file actions
65 lines (53 loc) · 1.81 KB
/
random_sampler.py
File metadata and controls
65 lines (53 loc) · 1.81 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from keras import ops
from keras import random
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.samplers.sampler import Sampler
@keras_hub_export("keras_hub.samplers.RandomSampler")
class RandomSampler(Sampler):
    """Random Sampler class.

    Draws the next token from the whole token distribution rather than from a
    truncated subset: any token may be chosen, and the likelihood of choosing
    it equals its probability.

    Args:
        seed: int. The random seed. Defaults to `None`.

    Call arguments:
        {{call_args}}

    Examples:
    ```python
    causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Pass by name to compile.
    causal_lm.compile(sampler="random")
    causal_lm.generate(["Keras is a"])

    # Pass by object to compile.
    sampler = keras_hub.samplers.RandomSampler(temperature=0.7)
    causal_lm.compile(sampler=sampler)
    causal_lm.generate(["Keras is a"])
    ```
    """

    def __init__(self, seed=None, **kwargs):
        # Any remaining options (e.g. `temperature` in the example above) are
        # handled by the base `Sampler`.
        super().__init__(**kwargs)
        # The raw seed is kept so `get_config` can serialize it; the
        # generator built from it is what `get_next_token` actually passes
        # to `random.categorical`.
        self.seed = seed
        self.seed_generator = random.SeedGenerator(seed)

    def get_next_token(self, probabilities):
        """Sample one next-token id from `probabilities`.

        Args:
            probabilities: tensor of token probabilities. Presumably shaped
                `(batch_size, vocab_size)`, since `random.categorical` expects
                2D logits — TODO confirm against the base `Sampler`.

        Returns:
            An `int32` tensor of sampled token ids, with the trailing
            single-sample axis squeezed away.
        """
        # tf does not support half precision multinomial sampling, so make
        # sure we have full precision here.
        full_precision = ops.cast(probabilities, "float32")
        # `random.categorical` takes unnormalized log-probabilities.
        logits = ops.log(full_precision)
        samples = random.categorical(
            logits,
            num_samples=1,
            dtype="int32",
            seed=self.seed_generator,
        )
        return ops.squeeze(samples, axis=-1)

    def get_config(self):
        """Return the base sampler config extended with this sampler's seed."""
        config = super().get_config()
        config["seed"] = self.seed
        return config