-
Notifications
You must be signed in to change notification settings - Fork 330
Expand file tree
/
Copy pathtop_p_sampler.py
More file actions
101 lines (89 loc) · 3.46 KB
/
top_p_sampler.py
File metadata and controls
101 lines (89 loc) · 3.46 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
from keras import ops
from keras import random
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.samplers.sampler import Sampler
@keras_hub_export("keras_hub.samplers.TopPSampler")
class TopPSampler(Sampler):
    """Sampler implementing top-p (nucleus) decoding.

    At each step, token predictions are ranked by probability and the
    candidate pool is truncated to the smallest prefix of that ranking
    whose cumulative probability exceeds `p`. The next token is then
    drawn at random from the truncated pool only.

    Args:
        p: float, the `p` value of top-p.
        k: int. If set, this argument defines a
            heuristic "top-k" cutoff applied before the "top-p" sampling. All
            logits not in the top `k` will be discarded, and the remaining
            logits will be sorted to find a cutoff point for `p`. Setting this
            arg can significantly speed sampling up by reducing the number
            of tokens to sort. Defaults to `None`.
        seed: int. The random seed. Defaults to `None`.

    Call arguments:
        {{call_args}}

    Examples:
    ```python
    causal_lm = keras_hub.models.GPT2CausalLM.from_preset("gpt2_base_en")

    # Pass by name to compile.
    causal_lm.compile(sampler="top_p")
    causal_lm.generate(["Keras is a"])

    # Pass by object to compile.
    sampler = keras_hub.samplers.TopPSampler(p=0.1, k=1_000)
    causal_lm.compile(sampler=sampler)
    causal_lm.generate(["Keras is a"])
    ```
    """

    def __init__(
        self,
        p=0.1,
        k=None,
        seed=None,
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.p = p
        self.k = k
        self.seed = seed
        self.seed_generator = random.SeedGenerator(seed)

    def get_next_token(self, probabilities):
        # Rank every candidate by probability. When `k` is set, pre-truncate
        # the pool to the top `k` tokens to reduce the sorting cost.
        vocab_size = ops.shape(probabilities)[1]
        pool_size = vocab_size if self.k is None else self.k
        ranked_probs, ranked_ids = ops.top_k(
            probabilities, k=pool_size, sorted=True
        )
        # Running total of probability mass along the ranked candidates.
        running_totals = ops.cumsum(ranked_probs, axis=-1)
        within_p = running_totals <= self.p
        # Shift the mask right by one position so the first token whose
        # running total crosses `p` is also kept (this also guarantees the
        # top-1 token is always a candidate).
        selection_mask = ops.concatenate(
            [ops.ones_like(within_p[:, :1]), within_p[:, :-1]], axis=-1
        )
        # Zero out all probability mass outside the nucleus, then sample
        # from what remains.
        filtered = ops.where(
            selection_mask,
            ranked_probs,
            ops.zeros(ops.shape(ranked_probs), dtype=ranked_probs.dtype),
        )
        picked_rank = random.categorical(
            # tf does not support half precision multinomial sampling, so make
            # sure we have full precision here.
            ops.cast(ops.log(filtered), "float32"),
            1,
            seed=self.seed_generator,
            dtype="int32",
        )
        # Translate the sampled rank back into a vocabulary id.
        token_id = ops.take_along_axis(ranked_ids, picked_rank, axis=-1)
        return ops.squeeze(token_id, axis=-1)

    def get_config(self):
        base = super().get_config()
        base.update(
            {
                "p": self.p,
                "k": self.k,
                "seed": self.seed,
            }
        )
        return base