diff --git a/keras_hub/src/samplers/random_sampler.py b/keras_hub/src/samplers/random_sampler.py index 368f7ca71e..589c58700c 100644 --- a/keras_hub/src/samplers/random_sampler.py +++ b/keras_hub/src/samplers/random_sampler.py @@ -45,8 +45,10 @@ def __init__( def get_next_token(self, probabilities): # Sample the next token from the probability distribution. + # tf does not support half precision multinomial sampling, so make + # sure we have full precision here. next_token_id = random.categorical( - ops.log(probabilities), + ops.log(ops.cast(probabilities, "float32")), 1, seed=self.seed_generator, dtype="int32", diff --git a/keras_hub/src/samplers/top_p_sampler.py b/keras_hub/src/samplers/top_p_sampler.py index 4477acaf77..81ff4b35b6 100644 --- a/keras_hub/src/samplers/top_p_sampler.py +++ b/keras_hub/src/samplers/top_p_sampler.py @@ -79,7 +79,9 @@ def get_next_token(self, probabilities): ops.zeros(ops.shape(sorted_preds), dtype=sorted_preds.dtype), ) sorted_next_token = random.categorical( - ops.log(probabilities), + # tf does not support half precision multinomial sampling, so make + # sure we have full precision here. + ops.log(ops.cast(probabilities, "float32")), 1, seed=self.seed_generator, dtype="int32",