Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 0 additions & 3 deletions keras_hub/api/layers/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,3 @@
from keras_hub.src.layers.preprocessing.v2.multi_segment_packer import (
MultiSegmentPacker as MultiSegmentPacker,
)
from keras_hub.src.layers.preprocessing.v2.start_end_packer import (
StartEndPacker as StartEndPacker,
)
3 changes: 0 additions & 3 deletions keras_hub/api/tokenizers/v2/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,6 @@
since your modifications would be overwritten.
"""

from keras_hub.src.tokenizers.v2.byte_pair_tokenizer import (
BytePairTokenizer as BytePairTokenizer,
)
from keras_hub.src.tokenizers.v2.sentence_piece_tokenizer import (
SentencePieceTokenizer as SentencePieceTokenizer,
)
7 changes: 6 additions & 1 deletion keras_hub/src/layers/preprocessing/preprocessing_layer.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,17 @@ class PreprocessingLayer(keras.layers.Layer):
"""Preprocessing layer base class."""

    def __init__(self, **kwargs):
        """Initialize the preprocessing layer base class.

        Args:
            **kwargs: Forwarded to `keras.layers.Layer.__init__`. The
                private `_allow_python_workflow` key (default `False`) is
                popped before forwarding; when true, the TF/TF-Text
                library check is skipped so the layer can run in a pure
                Python workflow.
        """
        # Pop the private flag first so it is not passed to the Layer base
        # class, which would reject unknown kwargs.
        _allow_python_workflow = kwargs.pop("_allow_python_workflow", False)
        if not _allow_python_workflow:
            # Historically these layers require TF + TF Text; only enforce
            # that when a pure-Python workflow was not requested.
            assert_tf_libs_installed(self.__class__.__name__)
        super().__init__(**kwargs)
        # Don't convert inputs (we want tf tensors not backend tensors).
        self._convert_input_args = False
        # Allow raw inputs like python strings.
        self._allow_non_tensor_positional_args = True
        # Allow Python workflow. Historically, KerasHub preprocessing layers
        # required TF and TF text libraries.
        self._allow_python_workflow = _allow_python_workflow
        # Most pre-preprocessing has no build.
        if not hasattr(self, "build"):
            self.built = True
Expand Down
160 changes: 159 additions & 1 deletion keras_hub/src/layers/preprocessing/start_end_packer.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,16 @@
import keras
import numpy as np

from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
from keras_hub.src.utils.tensor_utils import (
convert_preprocessing_outputs_python,
)
from keras_hub.src.utils.tensor_utils import convert_to_list
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
from keras_hub.src.utils.tensor_utils import in_tf_function
from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function

Expand Down Expand Up @@ -126,6 +134,8 @@ def __init__(
self._end_value = end_value

def check_special_value_type(value, value_name):
if value is None:
return None
if isinstance(value, (int, str)):
return [value]
if value and not isinstance(value, (list, tuple)):
Expand All @@ -146,7 +156,7 @@ def check_special_value_type(value, value_name):
self.padding_side = padding_side

@preprocessing_function
def call(
def _call_tf(
self,
inputs,
sequence_length=None,
Expand Down Expand Up @@ -203,6 +213,152 @@ def call(
return outputs, mask
return outputs

def _call_python(
self,
inputs,
sequence_length=None,
add_start_value=True,
add_end_value=True,
):
def _canonicalize_inputs(inputs):
if isinstance(inputs, (tuple, list)):
inputs = keras.tree.map_structure(convert_to_list, inputs)
if inputs and isinstance(inputs[0], (tuple, list)):
return inputs, True
else:
return [inputs], False
elif tf is not None and isinstance(
inputs, (tf.Tensor, tf.RaggedTensor)
):
unbatched = inputs.shape.rank == 1
if unbatched:
inputs = tf.expand_dims(inputs, 0)
if isinstance(inputs, tf.Tensor):
inputs = convert_to_list(inputs)
else:
inputs = inputs.to_list()
return inputs, not unbatched
elif keras.ops.is_tensor(inputs):
inputs = convert_to_list(inputs)
if inputs and isinstance(inputs[0], (tuple, list)):
return inputs, True
else:
return [inputs], False
else:
raise ValueError(
f"Input should be a list or a list of lists. "
f"Received: {inputs}"
)

def _get_type(inputs):
for sequence in inputs:
if sequence is not None and len(sequence) > 0:
return type(sequence[0])
return int # Default to int if all sequences are empty.

def _canonicalize_value(values, input_type):
if input_type is str:
return [str(v) for v in values]
else:
return [int(v) for v in values]

def _pad(x, pad_value, padding_side, sequence_length, input_type=None):
if padding_side not in ("left", "right"):
raise ValueError(
"padding_side must be 'left' or 'right'. "
f"Received: {padding_side}"
)
if pad_value is None:
pad_value = "" if input_type is str else 0
if padding_side == "right":
x = [
seq + [pad_value] * (sequence_length - len(seq))
for seq in x
]
else:
x = [
[pad_value] * (sequence_length - len(seq)) + seq
for seq in x
]
return x

def _canonicalize_outputs(outputs, dtype=None):
flat_outputs = keras.tree.flatten(outputs)
if not flat_outputs:
return np.array(outputs, dtype=dtype or "int32")
first_element = flat_outputs[0]
if not isinstance(first_element, str):
return np.array(outputs, dtype=dtype or "int32")
else:
return outputs

inputs, batched = _canonicalize_inputs(inputs)
input_type = _get_type(inputs)
sequence_length = sequence_length or self.sequence_length
x = inputs

# Truncate and normalize to list of lists.
truncation_length = sequence_length
if add_start_value and self.start_value is not None:
truncation_length -= len(self.start_value)
if add_end_value and self.end_value is not None:
truncation_length -= len(self.end_value)
x = [list(seq)[:truncation_length] for seq in x]

# Concatenate start and end tokens.
if add_start_value and self.start_value is not None:
start_value = _canonicalize_value(self.start_value, input_type)
x = [start_value + seq for seq in x]
if add_end_value and self.end_value is not None:
end_value = _canonicalize_value(self.end_value, input_type)
x = [seq + end_value for seq in x]

# Pad to desired length.
outputs = _pad(
x,
pad_value=self.pad_value,
padding_side=self.padding_side,
sequence_length=sequence_length,
input_type=input_type,
)
outputs = _canonicalize_outputs(outputs)
outputs = outputs[0] if not batched else outputs

if self.return_padding_mask:
masks = keras.tree.map_structure(lambda _: True, x)
masks = _pad(
masks,
pad_value=False,
padding_side=self.padding_side,
sequence_length=sequence_length,
)
masks = masks[0] if not batched else masks
masks = _canonicalize_outputs(masks, dtype="bool")
return convert_preprocessing_outputs_python((outputs, masks))
return convert_preprocessing_outputs_python(outputs)

def call(
self,
inputs,
sequence_length=None,
add_start_value=True,
add_end_value=True,
):
if in_tf_function():
return self._call_tf(
inputs,
sequence_length=sequence_length,
add_start_value=add_start_value,
add_end_value=add_end_value,
)
else:
return self._call_python(
inputs,
sequence_length=sequence_length,
add_start_value=add_start_value,
add_end_value=add_end_value,
)

def get_config(self):
config = super().get_config()
config.update(
Expand All @@ -220,4 +376,6 @@ def get_config(self):
def compute_output_shape(self, inputs_shape):
inputs_shape = list(inputs_shape)
inputs_shape[-1] = self.sequence_length
if self.return_padding_mask:
return tuple(inputs_shape), tuple(inputs_shape)
return tuple(inputs_shape)
Loading
Loading