diff --git a/keras_hub/src/models/causal_lm.py b/keras_hub/src/models/causal_lm.py index 32e00fb858..1ef50ae4b4 100644 --- a/keras_hub/src/models/causal_lm.py +++ b/keras_hub/src/models/causal_lm.py @@ -8,6 +8,7 @@ from keras_hub.src.api_export import keras_hub_export from keras_hub.src.models.task import Task from keras_hub.src.samplers.serialization import get as get_sampler +from keras.src.distribution import distribution_lib try: import tensorflow as tf @@ -354,6 +355,16 @@ def preprocess(x): x, sequence_length=max_length ) + def shard(x): + distribution = distribution_lib.distribution() + if distribution is None: + return x + result = {} + for key, value in x.items(): + layout = distribution.get_data_layout(value.shape) + result[key] = distribution_lib.distribute_tensor(value, layout) if layout else value + return result + def generate(x): return generate_function(x, stop_token_ids=stop_token_ids) @@ -394,6 +405,8 @@ def postprocess(x): if self.preprocessor is not None: inputs = [preprocess(x) for x in inputs] + inputs = [shard(x) for x in inputs] + if strip_prompt: outputs = [strip_prompt_function(generate(x), x) for x in inputs] else: