Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions keras_hub/api/layers/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,3 @@
from keras_hub.src.layers.preprocessing.v2.multi_segment_packer import (
MultiSegmentPacker as MultiSegmentPacker,
)
from keras_hub.src.layers.preprocessing.v2.start_end_packer import (
StartEndPacker as StartEndPacker,
)
3 changes: 0 additions & 3 deletions keras_hub/api/tokenizers/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
since your modifications would be overwritten.
"""

from keras_hub.src.tokenizers.v2.byte_pair_tokenizer import (
BytePairTokenizer as BytePairTokenizer,
)
from keras_hub.src.tokenizers.v2.sentence_piece_tokenizer import (
SentencePieceTokenizer as SentencePieceTokenizer,
)
7 changes: 6 additions & 1 deletion keras_hub/src/layers/preprocessing/preprocessing_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@ class PreprocessingLayer(keras.layers.Layer):
"""Preprocessing layer base class."""

def __init__(self, **kwargs):
    """Base initializer for KerasHub preprocessing layers.

    Args:
        **kwargs: Forwarded to `keras.layers.Layer`. The private
            `_allow_python_workflow` kwarg (default `False`) is popped
            before forwarding; when `True`, the TF/TF-Text install check
            is skipped so the layer can run a pure-Python path.
    """
    # Pop the private flag first so `keras.layers.Layer` never sees it.
    _allow_python_workflow = kwargs.pop("_allow_python_workflow", False)
    if not _allow_python_workflow:
        # TF and TF-Text are only required for the TF-based workflow.
        assert_tf_libs_installed(self.__class__.__name__)
    super().__init__(**kwargs)
    # Don't convert inputs (we want tf tensors not backend tensors).
    self._convert_input_args = False
    # Allow raw inputs like python strings.
    self._allow_non_tensor_positional_args = True
    # Allow Python workflow. Historically, KerasHub preprocessing layers
    # required TF and TF text libraries.
    self._allow_python_workflow = _allow_python_workflow
    # Most pre-preprocessing has no build.
    if not hasattr(self, "build"):
        self.built = True
Expand Down
171 changes: 166 additions & 5 deletions keras_hub/src/layers/preprocessing/start_end_packer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import keras
import numpy as np

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_hub.src.utils.tensor_utils import (
convert_preprocessing_outputs_python,
)
from keras_hub.src.utils.tensor_utils import convert_to_list
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
from keras_hub.src.utils.tensor_utils import in_tf_function
from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function

Expand Down Expand Up @@ -74,7 +82,7 @@ class StartEndPacker(PreprocessingLayer):
[ 1, 8, 9, 10, 11, 2]], dtype=int32)

Unbatched input (str).
>>> inputs = tf.constant(["this", "is", "fun"])
>>> inputs = ["this", "is", "fun"]
>>> start_end_packer = keras_hub.layers.StartEndPacker(
... sequence_length=6, start_value="<s>", end_value="</s>",
... pad_value="<pad>"
Expand All @@ -84,7 +92,7 @@ class StartEndPacker(PreprocessingLayer):
array(['<s>', 'this', 'is', 'fun', '</s>', '<pad>'], dtype='<U5')

Batched input (str).
>>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
>>> inputs = [["this", "is", "fun"], ["awesome"]]
>>> start_end_packer = keras_hub.layers.StartEndPacker(
... sequence_length=6, start_value="<s>", end_value="</s>",
... pad_value="<pad>"
Expand All @@ -95,7 +103,7 @@ class StartEndPacker(PreprocessingLayer):
['<s>', 'awesome', '</s>', '<pad>', '<pad>', '<pad>']], dtype='<U7')

Multiple start tokens.
>>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
>>> inputs = [["this", "is", "fun"], ["awesome"]]
>>> start_end_packer = keras_hub.layers.StartEndPacker(
... sequence_length=6, start_value=["</s>", "<s>"], end_value="</s>",
... pad_value="<pad>"
Expand All @@ -117,7 +125,10 @@ def __init__(
padding_side="right",
**kwargs,
):
super().__init__(name=name, **kwargs)
_allow_python_workflow = kwargs.pop("_allow_python_workflow", True)
super().__init__(
name=name, _allow_python_workflow=_allow_python_workflow, **kwargs
)

self.sequence_length = sequence_length

Expand All @@ -126,6 +137,8 @@ def __init__(
self._end_value = end_value

def check_special_value_type(value, value_name):
if value is None:
return None
if isinstance(value, (int, str)):
return [value]
if value and not isinstance(value, (list, tuple)):
Expand All @@ -146,7 +159,7 @@ def check_special_value_type(value, value_name):
self.padding_side = padding_side

@preprocessing_function
def call(
def _call_tf(
self,
inputs,
sequence_length=None,
Expand Down Expand Up @@ -203,6 +216,152 @@ def call(
return outputs, mask
return outputs

def _call_python(
    self,
    inputs,
    sequence_length=None,
    add_start_value=True,
    add_end_value=True,
):
    """Pack `inputs` using plain Python lists (no TF ops).

    Mirrors the TF path: truncates each sequence so the packed result
    fits `sequence_length`, prepends `self.start_value`, appends
    `self.end_value`, pads with `self.pad_value`, and optionally returns
    a padding mask alongside the outputs.

    Args:
        inputs: A list, list of lists, `tf.Tensor`/`tf.RaggedTensor`, or
            backend tensor of tokens (ints or strings).
        sequence_length: Optional override for `self.sequence_length`.
        add_start_value: Whether to prepend `self.start_value`.
        add_end_value: Whether to append `self.end_value`.

    Returns:
        The packed outputs, or `(outputs, mask)` when
        `self.return_padding_mask` is set. Numeric results are int32
        numpy arrays; string results stay as Python lists.
    """

    def _canonicalize_inputs(inputs):
        # Normalize every accepted input form to a list of token lists,
        # and report whether the original input was batched so the batch
        # dimension can be stripped again on the way out.
        if isinstance(inputs, (tuple, list)):
            inputs = keras.tree.map_structure(convert_to_list, inputs)
            if inputs and isinstance(inputs[0], (tuple, list)):
                return inputs, True
            else:
                return [inputs], False
        elif tf is not None and isinstance(
            inputs, (tf.Tensor, tf.RaggedTensor)
        ):
            # Rank-1 tensors are a single unbatched sequence.
            unbatched = inputs.shape.rank == 1
            if unbatched:
                inputs = tf.expand_dims(inputs, 0)
            if isinstance(inputs, tf.Tensor):
                inputs = convert_to_list(inputs)
            else:
                # Ragged tensors convert directly to nested lists.
                inputs = inputs.to_list()
            return inputs, not unbatched
        elif keras.ops.is_tensor(inputs):
            # Backend (non-TF) tensors, e.g. torch/jax.
            inputs = convert_to_list(inputs)
            if inputs and isinstance(inputs[0], (tuple, list)):
                return inputs, True
            else:
                return [inputs], False
        else:
            raise ValueError(
                f"Input should be a list or a list of lists. "
                f"Received: {inputs}"
            )

    def _get_type(inputs):
        # Element type of the first non-empty sequence drives how
        # start/end/pad values are cast.
        for sequence in inputs:
            if sequence is not None and len(sequence) > 0:
                return type(sequence[0])
        return int  # Default to int if all sequences are empty.

    def _canonicalize_value(values, input_type):
        # Cast configured start/end tokens to match the input dtype.
        if input_type is str:
            return [str(v) for v in values]
        else:
            return [int(v) for v in values]

    def _pad(x, pad_value, padding_side, sequence_length, input_type=None):
        # Pad each sequence in `x` to `sequence_length` on the given side.
        if padding_side not in ("left", "right"):
            raise ValueError(
                "padding_side must be 'left' or 'right'. "
                f"Received: {padding_side}"
            )
        if pad_value is None:
            # Sensible default pad value per dtype.
            pad_value = "" if input_type is str else 0
        if padding_side == "right":
            x = [
                seq + [pad_value] * (sequence_length - len(seq))
                for seq in x
            ]
        else:
            x = [
                [pad_value] * (sequence_length - len(seq)) + seq
                for seq in x
            ]
        return x

    def _canonicalize_outputs(outputs, dtype=None):
        # Numeric (and empty) outputs become numpy arrays; string
        # outputs are returned as nested Python lists unchanged.
        flat_outputs = keras.tree.flatten(outputs)
        if not flat_outputs:
            return np.array(outputs, dtype=dtype or "int32")
        first_element = flat_outputs[0]
        if not isinstance(first_element, str):
            return np.array(outputs, dtype=dtype or "int32")
        else:
            return outputs

    inputs, batched = _canonicalize_inputs(inputs)
    input_type = _get_type(inputs)
    sequence_length = sequence_length or self.sequence_length
    x = inputs

    # Truncate and normalize to list of lists. Reserve room for the
    # start/end tokens so the final packed length is `sequence_length`.
    truncation_length = sequence_length
    if add_start_value and self.start_value is not None:
        truncation_length -= len(self.start_value)
    if add_end_value and self.end_value is not None:
        truncation_length -= len(self.end_value)
    x = [list(seq)[:truncation_length] for seq in x]

    # Concatenate start and end tokens.
    if add_start_value and self.start_value is not None:
        start_value = _canonicalize_value(self.start_value, input_type)
        x = [start_value + seq for seq in x]
    if add_end_value and self.end_value is not None:
        end_value = _canonicalize_value(self.end_value, input_type)
        x = [seq + end_value for seq in x]

    # Pad to desired length.
    outputs = _pad(
        x,
        pad_value=self.pad_value,
        padding_side=self.padding_side,
        sequence_length=sequence_length,
    )
    outputs = _canonicalize_outputs(outputs)
    # Strip the batch dimension added for unbatched inputs.
    outputs = outputs[0] if not batched else outputs

    if self.return_padding_mask:
        # Mask is True over real (unpadded) tokens, False over padding.
        masks = keras.tree.map_structure(lambda _: True, x)
        masks = _pad(
            masks,
            pad_value=False,
            padding_side=self.padding_side,
            sequence_length=sequence_length,
        )
        masks = masks[0] if not batched else masks
        masks = _canonicalize_outputs(masks, dtype="bool")
        return convert_preprocessing_outputs_python((outputs, masks))
    return convert_preprocessing_outputs_python(outputs)

def call(
    self,
    inputs,
    sequence_length=None,
    add_start_value=True,
    add_end_value=True,
):
    """Pack `inputs`, dispatching to the TF or pure-Python implementation.

    The TF path is used whenever we are tracing inside a `tf.function`
    or the layer was constructed without the Python workflow enabled;
    otherwise the list-based Python path is taken.
    """
    use_tf_path = in_tf_function() or not self._allow_python_workflow
    impl = self._call_tf if use_tf_path else self._call_python
    return impl(
        inputs,
        sequence_length=sequence_length,
        add_start_value=add_start_value,
        add_end_value=add_end_value,
    )

def get_config(self):
config = super().get_config()
config.update(
Expand All @@ -220,4 +379,6 @@ def get_config(self):
def compute_output_shape(self, inputs_shape):
    """Output shape replaces the last dimension with `sequence_length`.

    When `return_padding_mask` is set, the layer returns a pair of
    equally shaped outputs (packed tokens and the boolean mask), so two
    identical shapes are returned.
    """
    packed_shape = tuple(inputs_shape[:-1]) + (self.sequence_length,)
    if self.return_padding_mask:
        # Token output and padding mask share the same shape.
        return packed_shape, packed_shape
    return packed_shape
Loading
Loading