Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion keras_hub/src/samplers/random_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,10 @@ def __init__(

def get_next_token(self, probabilities):
# Sample the next token from the probability distribution.
# tf does not support half precision multinomial sampling, so make
# sure we have full precision here.
next_token_id = random.categorical(
ops.log(probabilities),
ops.cast(ops.log(probabilities), "float32"),
1,
seed=self.seed_generator,
dtype="int32",
Expand Down
4 changes: 3 additions & 1 deletion keras_hub/src/samplers/top_p_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,9 @@ def get_next_token(self, probabilities):
ops.zeros(ops.shape(sorted_preds), dtype=sorted_preds.dtype),
)
sorted_next_token = random.categorical(
ops.log(probabilities),
# tf does not support half precision multinomial sampling, so make
# sure we have full precision here.
ops.cast(ops.log(probabilities), "float32"),
1,
seed=self.seed_generator,
dtype="int32",
Expand Down
Loading