diff --git a/keras_hub/api/layers/v2/__init__.py b/keras_hub/api/layers/v2/__init__.py
index a2a328a36a..2b4a31a63f 100644
--- a/keras_hub/api/layers/v2/__init__.py
+++ b/keras_hub/api/layers/v2/__init__.py
@@ -7,6 +7,3 @@
from keras_hub.src.layers.preprocessing.v2.multi_segment_packer import (
MultiSegmentPacker as MultiSegmentPacker,
)
-from keras_hub.src.layers.preprocessing.v2.start_end_packer import (
- StartEndPacker as StartEndPacker,
-)
diff --git a/keras_hub/api/tokenizers/v2/__init__.py b/keras_hub/api/tokenizers/v2/__init__.py
index c80153ff31..ca2cf495a1 100644
--- a/keras_hub/api/tokenizers/v2/__init__.py
+++ b/keras_hub/api/tokenizers/v2/__init__.py
@@ -4,9 +4,6 @@
since your modifications would be overwritten.
"""
-from keras_hub.src.tokenizers.v2.byte_pair_tokenizer import (
- BytePairTokenizer as BytePairTokenizer,
-)
from keras_hub.src.tokenizers.v2.sentence_piece_tokenizer import (
SentencePieceTokenizer as SentencePieceTokenizer,
)
diff --git a/keras_hub/src/layers/preprocessing/preprocessing_layer.py b/keras_hub/src/layers/preprocessing/preprocessing_layer.py
index 5050cd529f..fe4916da40 100644
--- a/keras_hub/src/layers/preprocessing/preprocessing_layer.py
+++ b/keras_hub/src/layers/preprocessing/preprocessing_layer.py
@@ -7,12 +7,17 @@ class PreprocessingLayer(keras.layers.Layer):
"""Preprocessing layer base class."""
def __init__(self, **kwargs):
- assert_tf_libs_installed(self.__class__.__name__)
+ _allow_python_workflow = kwargs.pop("_allow_python_workflow", False)
+ if not _allow_python_workflow:
+ assert_tf_libs_installed(self.__class__.__name__)
super().__init__(**kwargs)
# Don't convert inputs (we want tf tensors not backend tensors).
self._convert_input_args = False
# Allow raw inputs like python strings.
self._allow_non_tensor_positional_args = True
+ # Allow Python workflow. Historically, KerasHub preprocessing layers
+ # required TF and TF text libraries.
+ self._allow_python_workflow = _allow_python_workflow
# Most pre-preprocessing has no build.
if not hasattr(self, "build"):
self.built = True
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer.py b/keras_hub/src/layers/preprocessing/start_end_packer.py
index efe10a4585..67cb0a4ec4 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer.py
@@ -1,8 +1,16 @@
+import keras
+import numpy as np
+
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.preprocessing_layer import (
PreprocessingLayer,
)
+from keras_hub.src.utils.tensor_utils import (
+ convert_preprocessing_outputs_python,
+)
+from keras_hub.src.utils.tensor_utils import convert_to_list
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import in_tf_function
from keras_hub.src.utils.tensor_utils import pad
from keras_hub.src.utils.tensor_utils import preprocessing_function
@@ -74,7 +82,7 @@ class StartEndPacker(PreprocessingLayer):
[ 1, 8, 9, 10, 11, 2]], dtype=int32)
Unbatched input (str).
- >>> inputs = tf.constant(["this", "is", "fun"])
+ >>> inputs = ["this", "is", "fun"]
>>> start_end_packer = keras_hub.layers.StartEndPacker(
... sequence_length=6, start_value="<s>", end_value="</s>",
... pad_value="<pad>"
@@ -84,7 +92,7 @@ class StartEndPacker(PreprocessingLayer):
array(['<s>', 'this', 'is', 'fun', '</s>', '<pad>'], dtype='<U5')
Batched input (str).
- >>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
+ >>> inputs = [["this", "is", "fun"], ["awesome"]]
>>> start_end_packer = keras_hub.layers.StartEndPacker(
... sequence_length=6, start_value="<s>", end_value="</s>",
... pad_value="<pad>"
@@ -95,7 +103,7 @@ class StartEndPacker(PreprocessingLayer):
['<s>', 'awesome', '</s>', '<pad>', '<pad>', '<pad>']], dtype='<U7')
Multiple start tokens.
- >>> inputs = tf.ragged.constant([["this", "is", "fun"], ["awesome"]])
+ >>> inputs = [["this", "is", "fun"], ["awesome"]]
>>> start_end_packer = keras_hub.layers.StartEndPacker(
... sequence_length=6, start_value=["</s>", "<s>"], end_value="</s>",
... pad_value="<pad>"
@@ -117,7 +125,10 @@ def __init__(
padding_side="right",
**kwargs,
):
- super().__init__(name=name, **kwargs)
+ _allow_python_workflow = kwargs.pop("_allow_python_workflow", True)
+ super().__init__(
+ name=name, _allow_python_workflow=_allow_python_workflow, **kwargs
+ )
self.sequence_length = sequence_length
@@ -126,6 +137,8 @@ def __init__(
self._end_value = end_value
def check_special_value_type(value, value_name):
+ if value is None:
+ return None
if isinstance(value, (int, str)):
return [value]
if value and not isinstance(value, (list, tuple)):
@@ -146,7 +159,7 @@ def check_special_value_type(value, value_name):
self.padding_side = padding_side
@preprocessing_function
- def call(
+ def _call_tf(
self,
inputs,
sequence_length=None,
@@ -203,6 +216,152 @@ def call(
return outputs, mask
return outputs
+ def _call_python(
+ self,
+ inputs,
+ sequence_length=None,
+ add_start_value=True,
+ add_end_value=True,
+ ):
+ def _canonicalize_inputs(inputs):
+ if isinstance(inputs, (tuple, list)):
+ inputs = keras.tree.map_structure(convert_to_list, inputs)
+ if inputs and isinstance(inputs[0], (tuple, list)):
+ return inputs, True
+ else:
+ return [inputs], False
+ elif tf is not None and isinstance(
+ inputs, (tf.Tensor, tf.RaggedTensor)
+ ):
+ unbatched = inputs.shape.rank == 1
+ if unbatched:
+ inputs = tf.expand_dims(inputs, 0)
+ if isinstance(inputs, tf.Tensor):
+ inputs = convert_to_list(inputs)
+ else:
+ inputs = inputs.to_list()
+ return inputs, not unbatched
+ elif keras.ops.is_tensor(inputs):
+ inputs = convert_to_list(inputs)
+ if inputs and isinstance(inputs[0], (tuple, list)):
+ return inputs, True
+ else:
+ return [inputs], False
+ else:
+ raise ValueError(
+ f"Input should be a list or a list of lists. "
+ f"Received: {inputs}"
+ )
+
+ def _get_type(inputs):
+ for sequence in inputs:
+ if sequence is not None and len(sequence) > 0:
+ return type(sequence[0])
+ return int # Default to int if all sequences are empty.
+
+ def _canonicalize_value(values, input_type):
+ if input_type is str:
+ return [str(v) for v in values]
+ else:
+ return [int(v) for v in values]
+
+ def _pad(x, pad_value, padding_side, sequence_length, input_type=None):
+ if padding_side not in ("left", "right"):
+ raise ValueError(
+ "padding_side must be 'left' or 'right'. "
+ f"Received: {padding_side}"
+ )
+ if pad_value is None:
+ pad_value = "" if input_type is str else 0
+ if padding_side == "right":
+ x = [
+ seq + [pad_value] * (sequence_length - len(seq))
+ for seq in x
+ ]
+ else:
+ x = [
+ [pad_value] * (sequence_length - len(seq)) + seq
+ for seq in x
+ ]
+ return x
+
+ def _canonicalize_outputs(outputs, dtype=None):
+ flat_outputs = keras.tree.flatten(outputs)
+ if not flat_outputs:
+ return np.array(outputs, dtype=dtype or "int32")
+ first_element = flat_outputs[0]
+ if not isinstance(first_element, str):
+ return np.array(outputs, dtype=dtype or "int32")
+ else:
+ return outputs
+
+ inputs, batched = _canonicalize_inputs(inputs)
+ input_type = _get_type(inputs)
+ sequence_length = sequence_length or self.sequence_length
+ x = inputs
+
+ # Truncate and normalize to list of lists.
+ truncation_length = sequence_length
+ if add_start_value and self.start_value is not None:
+ truncation_length -= len(self.start_value)
+ if add_end_value and self.end_value is not None:
+ truncation_length -= len(self.end_value)
+ x = [list(seq)[:truncation_length] for seq in x]
+
+ # Concatenate start and end tokens.
+ if add_start_value and self.start_value is not None:
+ start_value = _canonicalize_value(self.start_value, input_type)
+ x = [start_value + seq for seq in x]
+ if add_end_value and self.end_value is not None:
+ end_value = _canonicalize_value(self.end_value, input_type)
+ x = [seq + end_value for seq in x]
+
+ # Pad to desired length.
+ outputs = _pad(
+ x,
+ pad_value=self.pad_value,
+ padding_side=self.padding_side,
+ sequence_length=sequence_length,
+ input_type=input_type,
+ )
+ outputs = _canonicalize_outputs(outputs)
+ outputs = outputs[0] if not batched else outputs
+
+ if self.return_padding_mask:
+ masks = keras.tree.map_structure(lambda _: True, x)
+ masks = _pad(
+ masks,
+ pad_value=False,
+ padding_side=self.padding_side,
+ sequence_length=sequence_length,
+ )
+ masks = masks[0] if not batched else masks
+ masks = _canonicalize_outputs(masks, dtype="bool")
+ return convert_preprocessing_outputs_python((outputs, masks))
+ return convert_preprocessing_outputs_python(outputs)
+
+ def call(
+ self,
+ inputs,
+ sequence_length=None,
+ add_start_value=True,
+ add_end_value=True,
+ ):
+ if not self._allow_python_workflow or in_tf_function():
+ return self._call_tf(
+ inputs,
+ sequence_length=sequence_length,
+ add_start_value=add_start_value,
+ add_end_value=add_end_value,
+ )
+ else:
+ return self._call_python(
+ inputs,
+ sequence_length=sequence_length,
+ add_start_value=add_start_value,
+ add_end_value=add_end_value,
+ )
+
def get_config(self):
config = super().get_config()
config.update(
@@ -220,4 +379,6 @@ def get_config(self):
def compute_output_shape(self, inputs_shape):
inputs_shape = list(inputs_shape)
inputs_shape[-1] = self.sequence_length
+ if self.return_padding_mask:
+ return tuple(inputs_shape), tuple(inputs_shape)
return tuple(inputs_shape)
diff --git a/keras_hub/src/layers/preprocessing/start_end_packer_test.py b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
index 78f65405f0..63d52e1c54 100644
--- a/keras_hub/src/layers/preprocessing/start_end_packer_test.py
+++ b/keras_hub/src/layers/preprocessing/start_end_packer_test.py
@@ -1,20 +1,29 @@
import tensorflow as tf
+from absl.testing import parameterized
from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_hub.src.tests.test_case import TestCase
class StartEndPackerTest(TestCase):
- def test_dense_input(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_dense_input(self, allow_python_workflow):
# right padding
input_data = [5, 6, 7]
- start_end_packer = StartEndPacker(sequence_length=5)
+ start_end_packer = StartEndPacker(
+ sequence_length=5, _allow_python_workflow=allow_python_workflow
+ )
output = start_end_packer(input_data)
expected_output = [5, 6, 7, 0, 0]
self.assertAllEqual(output, expected_output)
# left padding
start_end_packer = StartEndPacker(
- sequence_length=5, padding_side="left"
+ sequence_length=5,
+ padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [0, 0, 5, 6, 7]
@@ -28,54 +37,85 @@ def test_bfloat16_dtype(self):
output = start_end_packer(input_data)
self.assertDTypeEqual(output, "int32")
- def test_dense_2D_input(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_dense_2D_input(self, allow_python_workflow):
# right padding
input_data = [[5, 6, 7]]
- start_end_packer = StartEndPacker(sequence_length=5)
+ start_end_packer = StartEndPacker(
+ sequence_length=5, _allow_python_workflow=allow_python_workflow
+ )
output = start_end_packer(input_data)
expected_output = [[5, 6, 7, 0, 0]]
self.assertAllEqual(output, expected_output)
# left padding
start_end_packer = StartEndPacker(
- sequence_length=5, padding_side="left"
+ sequence_length=5,
+ padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[0, 0, 5, 6, 7]]
self.assertAllEqual(output, expected_output)
- def test_ragged_input(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_ragged_input(self, allow_python_workflow):
# right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
- start_end_packer = StartEndPacker(sequence_length=5)
+ start_end_packer = StartEndPacker(
+ sequence_length=5, _allow_python_workflow=allow_python_workflow
+ )
output = start_end_packer(input_data)
expected_output = [[5, 6, 7, 0, 0], [8, 9, 10, 11, 0]]
self.assertAllEqual(output, expected_output)
# left padding
start_end_packer = StartEndPacker(
- sequence_length=5, padding_side="left"
+ sequence_length=5,
+ padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[0, 0, 5, 6, 7], [0, 8, 9, 10, 11]]
self.assertAllEqual(output, expected_output)
- def test_start_end_token(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_start_end_token(self, allow_python_workflow):
# right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
- sequence_length=6, start_value=1, end_value=2
+ sequence_length=6,
+ start_value=1,
+ end_value=2,
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
self.assertAllEqual(output, expected_output)
# left padding
start_end_packer = StartEndPacker(
- sequence_length=6, start_value=1, end_value=2, padding_side="left"
+ sequence_length=6,
+ start_value=1,
+ end_value=2,
+ padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
self.assertAllEqual(output, expected_output)
- def test_multiple_start_end_tokens(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_multiple_start_end_tokens(self, allow_python_workflow):
# right padding
input_data = [[5, 6, 7], [8, 9, 10, 11, 12, 13]]
start_end_packer = StartEndPacker(
@@ -83,6 +123,7 @@ def test_multiple_start_end_tokens(self):
start_value=[1, 2],
end_value=[3, 4],
pad_value=0,
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[1, 2, 5, 6, 7, 3, 4, 0], [1, 2, 8, 9, 10, 11, 3, 4]]
@@ -95,16 +136,25 @@ def test_multiple_start_end_tokens(self):
end_value=[3, 4],
pad_value=0,
padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[0, 1, 2, 5, 6, 7, 3, 4], [1, 2, 8, 9, 10, 11, 3, 4]]
self.assertAllEqual(output, expected_output)
- def test_start_end_padding_value(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_start_end_padding_value(self, allow_python_workflow):
# right padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
- sequence_length=7, start_value=1, end_value=2, pad_value=3
+ sequence_length=7,
+ start_value=1,
+ end_value=2,
+ pad_value=3,
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[1, 5, 6, 7, 2, 3, 3], [1, 8, 9, 10, 11, 2, 3]]
@@ -117,18 +167,24 @@ def test_start_end_padding_value(self):
end_value=2,
pad_value=3,
padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[3, 3, 1, 5, 6, 7, 2], [3, 1, 8, 9, 10, 11, 2]]
self.assertAllEqual(output, expected_output)
- def test_truncation(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_truncation(self, allow_python_workflow):
# right padding
input_data = list(range(10))
packer = StartEndPacker(
sequence_length=7,
start_value=98,
end_value=99,
+ _allow_python_workflow=allow_python_workflow,
)
expected_output = [98, 0, 1, 2, 3, 4, 99]
self.assertAllEqual(packer(input_data), expected_output)
@@ -139,15 +195,21 @@ def test_truncation(self):
start_value=98,
end_value=99,
padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
self.assertAllEqual(packer(input_data), expected_output)
- def test_truncation_wo_endvalue(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_truncation_wo_endvalue(self, allow_python_workflow):
# right padding
input_data = list(range(10))
packer = StartEndPacker(
sequence_length=7,
start_value=98,
+ _allow_python_workflow=allow_python_workflow,
)
expected_output = [98, 0, 1, 2, 3, 4, 5]
self.assertAllEqual(packer(input_data), expected_output)
@@ -157,14 +219,23 @@ def test_truncation_wo_endvalue(self):
sequence_length=7,
start_value=98,
padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
self.assertAllEqual(packer(input_data), expected_output)
- def test_end_token_value_during_truncation(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_end_token_value_during_truncation(self, allow_python_workflow):
# right padding
input_data = [[5, 6], [8, 9, 10, 11, 12, 13]]
start_end_packer = StartEndPacker(
- sequence_length=5, start_value=1, end_value=2, pad_value=0
+ sequence_length=5,
+ start_value=1,
+ end_value=2,
+ pad_value=0,
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[1, 5, 6, 2, 0], [1, 8, 9, 10, 2]]
@@ -177,12 +248,17 @@ def test_end_token_value_during_truncation(self):
end_value=2,
pad_value=0,
padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [[0, 1, 5, 6, 2], [1, 8, 9, 10, 2]]
self.assertAllEqual(output, expected_output)
- def test_string_input(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_string_input(self, allow_python_workflow):
# right padding
input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
start_end_packer = StartEndPacker(
@@ -190,6 +266,7 @@ def test_string_input(self):
start_value="[START]",
end_value="[END]",
pad_value="[PAD]",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [
@@ -205,6 +282,7 @@ def test_string_input(self):
end_value="[END]",
pad_value="[PAD]",
padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [
@@ -213,7 +291,13 @@ def test_string_input(self):
]
self.assertAllEqual(output, expected_output)
- def test_string_input_with_multiple_special_values(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_string_input_with_multiple_special_values(
+ self, allow_python_workflow
+ ):
# right padding
input_data = [["KerasHub", "is", "awesome"], ["amazing"]]
start_end_packer = StartEndPacker(
@@ -221,6 +305,7 @@ def test_string_input_with_multiple_special_values(self):
start_value=["[END]", "[START]"],
end_value="[END]",
pad_value="[PAD]",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [
@@ -236,6 +321,7 @@ def test_string_input_with_multiple_special_values(self):
end_value="[END]",
pad_value="[PAD]",
padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output = start_end_packer(input_data)
expected_output = [
@@ -262,21 +348,35 @@ def test_batch(self):
exp_output = [[1, 5, 6, 7, 2, 3, 3], [1, 8, 9, 10, 11, 2, 3]]
self.assertAllEqual(output, exp_output)
- def test_call_overrides(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_call_overrides(self, allow_python_workflow):
x = [5, 6, 7]
- packer = StartEndPacker(start_value=1, end_value=2, sequence_length=4)
+ packer = StartEndPacker(
+ start_value=1,
+ end_value=2,
+ sequence_length=4,
+ _allow_python_workflow=allow_python_workflow,
+ )
self.assertAllEqual(packer(x), [1, 5, 6, 2])
self.assertAllEqual(packer(x, add_start_value=False), [5, 6, 7, 2])
self.assertAllEqual(packer(x, add_end_value=False), [1, 5, 6, 7])
self.assertAllEqual(packer(x, sequence_length=2), [1, 2])
- def test_get_config(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_get_config(self, allow_python_workflow):
start_end_packer = StartEndPacker(
sequence_length=512,
start_value=10,
end_value=20,
pad_value=100,
name="start_end_packer_test",
+ _allow_python_workflow=allow_python_workflow,
)
config = start_end_packer.get_config()
@@ -289,7 +389,11 @@ def test_get_config(self):
self.assertEqual(config, {**config, **expected_config_subset})
- def test_return_padding_mask(self):
+ @parameterized.named_parameters(
+ ("allow_python_workflow", True),
+ ("disallow_python_workflow", False),
+ )
+ def test_return_padding_mask(self, allow_python_workflow):
# right_padding
input_data = [[5, 6, 7], [8, 9, 10, 11]]
start_end_packer = StartEndPacker(
@@ -297,6 +401,7 @@ def test_return_padding_mask(self):
start_value=1,
end_value=2,
return_padding_mask=True,
+ _allow_python_workflow=allow_python_workflow,
)
output, padding_mask = start_end_packer(input_data)
expected_output = [[1, 5, 6, 7, 2, 0], [1, 8, 9, 10, 11, 2]]
@@ -304,7 +409,6 @@ def test_return_padding_mask(self):
[True, True, True, True, True, False],
[True, True, True, True, True, True],
]
- print(padding_mask)
self.assertAllEqual(output, expected_output)
self.assertAllEqual(padding_mask, expected_padding_mask)
@@ -315,6 +419,7 @@ def test_return_padding_mask(self):
end_value=2,
return_padding_mask=True,
padding_side="left",
+ _allow_python_workflow=allow_python_workflow,
)
output, padding_mask = start_end_packer(input_data)
expected_output = [[0, 1, 5, 6, 7, 2], [1, 8, 9, 10, 11, 2]]
diff --git a/keras_hub/src/layers/preprocessing/v2/start_end_packer.py b/keras_hub/src/layers/preprocessing/v2/start_end_packer.py
deleted file mode 100644
index 85bf0cfde1..0000000000
--- a/keras_hub/src/layers/preprocessing/v2/start_end_packer.py
+++ /dev/null
@@ -1,266 +0,0 @@
-import keras
-import numpy as np
-
-from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.v2.preprocessing_layer import (
- PreprocessingLayer,
-)
-from keras_hub.src.utils.tensor_utils import convert_to_list
-
-
-@keras_hub_export("keras_hub.layers.v2.StartEndPacker")
-class StartEndPacker(PreprocessingLayer):
- """Adds start and end tokens to a sequence and pads to a fixed length.
-
- This layer is useful when tokenizing inputs for tasks like translation,
- where each sequence should include a start and end marker. It should
- be called after tokenization. The layer will first trim inputs to fit, then
- add start/end tokens, and finally pad, if necessary, to `sequence_length`.
-
- Input data should be passed as lists. For batched input, inputs should be a
- list of lists. For unbatched inputs, each element should be a list.
-
- Args:
- sequence_length: int. The desired output length.
- start_value: int/str/list/tuple. The ID(s) or token(s) that are to be
- placed at the start of each sequence. The dtype must match the dtype
- of the input tensors to the layer. If `None`, no start value will be
- added.
- end_value: int/str/list/tuple. The ID(s) or token(s) that are to be
- placed at the end of each input segment. The dtype must match the
- dtype of the input tensors to the layer. If `None`, no end value
- will be added.
- pad_value: int/str. The ID or token that is to be placed into the
- unused positions after the last segment in the sequence. If `None`,
- 0 or "" will be added depending on the dtype of the input tensor.
- return_padding_mask: bool. Whether to return a boolean padding mask of
- all locations that are filled in with the `pad_value`.
- padding_side: str. Whether to pad the input on the "left" or "right".
- Defaults to "right".
-
- Call arguments:
- inputs: A list or a list of lists of python strings or ints.
- sequence_length: Pass to override the configured `sequence_length` of
- the layer.
- add_start_value: Pass `False` to not append a start value for this
- input.
- add_end_value: Pass `False` to not append an end value for this
- input.
-
- Examples:
-
- Unbatched input (int).
- >>> inputs = [5, 6, 7]
- >>> start_end_packer = keras_hub.layers.StartEndPacker(
- ... sequence_length=7, start_value=1, end_value=2,
- ... )
- >>> outputs = start_end_packer(inputs)
- >>> np.array(outputs)
- array([1, 5, 6, 7, 2, 0, 0], dtype=int32)
-
- Batched input (int).
- >>> inputs = [[5, 6, 7], [8, 9, 10, 11, 12, 13, 14]]
- >>> start_end_packer = keras_hub.layers.StartEndPacker(
- ... sequence_length=6, start_value=1, end_value=2,
- ... )
- >>> outputs = start_end_packer(inputs)
- >>> np.array(outputs)
- array([[ 1, 5, 6, 7, 2, 0],
- [ 1, 8, 9, 10, 11, 2]], dtype=int32)
-
- Unbatched input (str).
- >>> inputs = ["this", "is", "fun"]
- >>> start_end_packer = keras_hub.layers.StartEndPacker(
- ... sequence_length=6, start_value="<s>", end_value="</s>",
- ... pad_value="<pad>"
- ... )
- >>> outputs = start_end_packer(inputs)
- >>> np.array(outputs).astype("U")
- array(['<s>', 'this', 'is', 'fun', '</s>', '<pad>'], dtype='<U5')
-
- Batched input (str).
- >>> inputs = [["this", "is", "fun"], ["awesome"]]
- >>> start_end_packer = keras_hub.layers.StartEndPacker(
- ... sequence_length=6, start_value="<s>", end_value="</s>",
- ... pad_value="<pad>"
- ... )
- >>> outputs = start_end_packer(inputs)
- >>> np.array(outputs).astype("U")
- array([['<s>', 'this', 'is', 'fun', '</s>', '<pad>'],
-        ['<s>', 'awesome', '</s>', '<pad>', '<pad>', '<pad>']], dtype='<U7')
-
- Multiple start tokens.
- >>> inputs = [["this", "is", "fun"], ["awesome"]]
- >>> start_end_packer = keras_hub.layers.StartEndPacker(
- ... sequence_length=6, start_value=["</s>", "<s>"], end_value="</s>",
- ... pad_value="<pad>"
- ... )
- >>> outputs = start_end_packer(inputs)
- >>> np.array(outputs).astype("U")
- array([['</s>', '<s>', 'this', 'is', 'fun', '</s>'],
-        ['</s>', '<s>', 'awesome', '</s>', '<pad>', '<pad>']], dtype='<U7')
[NOTE(review): the remainder of the deleted v2/start_end_packer.py and the diff
header for keras_hub/src/models/bart/bart_seq_2_seq_lm_preprocessor_test.py were
lost during text extraction and cannot be reconstructed from this view.]
- self.vocab = ["<s>", "<pad>", "</s>", "air", "Ġair", "plane", "Ġat"]
- self.vocab += ["port", "<mask>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<s>", "<pad>", "</s>", "<mask>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = BartTokenizer(
vocabulary=self.vocab, merges=self.merges
)
@@ -37,12 +41,12 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "encoder_token_ids": [[0, 4, 5, 6, 2]],
+ "encoder_token_ids": [[3, 27, 18, 28, 0]],
"encoder_padding_mask": [[1, 1, 1, 1, 1]],
- "decoder_token_ids": [[2, 0, 4, 5, 4, 7, 2, 1]],
+ "decoder_token_ids": [[0, 3, 27, 18, 27, 20, 0, 2]],
"decoder_padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
- [[0, 4, 5, 4, 7, 2, 1, 1]],
+ [[3, 27, 18, 27, 20, 0, 2, 2]],
[[1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 0.0, 0.0]],
),
token_id_key="decoder_token_ids",
@@ -58,9 +62,9 @@ def test_generate_preprocess(self):
self.assertAllClose(
output,
{
- "encoder_token_ids": [[0, 4, 5, 6, 2]],
+ "encoder_token_ids": [[3, 27, 18, 28, 0]],
"encoder_padding_mask": [[1, 1, 1, 1, 1]],
- "decoder_token_ids": [[2, 0, 4, 5, 4, 7, 1, 1]],
+ "decoder_token_ids": [[0, 3, 27, 18, 27, 20, 2, 2]],
"decoder_padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
},
)
@@ -68,7 +72,7 @@ def test_generate_preprocess(self):
def test_generate_postprocess(self):
preprocessor = BartSeq2SeqLMPreprocessor(**self.init_kwargs)
input_data = {
- "decoder_token_ids": [0, 4, 5, 6, 2],
+ "decoder_token_ids": [3, 27, 18, 28, 0],
"decoder_padding_mask": [1, 1, 1, 1, 1],
}
output = preprocessor.generate_postprocess(input_data)
diff --git a/keras_hub/src/models/bart/bart_seq_2_seq_lm_test.py b/keras_hub/src/models/bart/bart_seq_2_seq_lm_test.py
index 7570794630..10cbd8bd13 100644
--- a/keras_hub/src/models/bart/bart_seq_2_seq_lm_test.py
+++ b/keras_hub/src/models/bart/bart_seq_2_seq_lm_test.py
@@ -14,19 +14,24 @@
class BartSeq2SeqLMTest(TestCase):
def setUp(self):
- self.vocab = ["<s>", "<pad>", "</s>", "air", "Ġair", "plane", "Ġat"]
- self.vocab += ["port", "<mask>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<s>", "<pad>", "</s>", "<mask>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = BartSeq2SeqLMPreprocessor(
BartTokenizer(vocabulary=self.vocab, merges=self.merges),
encoder_sequence_length=12,
decoder_sequence_length=10,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = BartBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_heads=2,
hidden_dim=4,
@@ -53,7 +58,7 @@ def test_causal_lm_basics(self):
cls=BartSeq2SeqLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 10, 9),
+ expected_output_shape=(2, 10, self.vocabulary_size),
)
def test_generate(self):
diff --git a/keras_hub/src/models/bart/bart_tokenizer_test.py b/keras_hub/src/models/bart/bart_tokenizer_test.py
index e5eb9ad9e7..9a69d95708 100644
--- a/keras_hub/src/models/bart/bart_tokenizer_test.py
+++ b/keras_hub/src/models/bart/bart_tokenizer_test.py
@@ -6,12 +6,16 @@
class BartTokenizerTest(TestCase):
def setUp(self):
- self.vocab = ["<s>", "<pad>", "</s>", "air", "Ġair", "plane", "Ġat"]
- self.vocab += ["port", "<mask>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<s>", "<pad>", "</s>", "<mask>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
self.input_data = [
" airplane at airport",
@@ -23,7 +27,10 @@ def test_tokenizer_basics(self):
cls=BartTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]],
+ expected_output=[
+ [3, 27, 18, 28, 27, 20, 0, 2],
+ [27, 18, 27, 20],
+ ],
expected_detokenize_output=[
" airplane at airport",
" airplane airport",
diff --git a/keras_hub/src/models/bloom/bloom_causal_lm_preprocessor_test.py b/keras_hub/src/models/bloom/bloom_causal_lm_preprocessor_test.py
index c8d13f1e70..b334e94a9e 100644
--- a/keras_hub/src/models/bloom/bloom_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/bloom/bloom_causal_lm_preprocessor_test.py
@@ -9,12 +9,16 @@
class BloomCausalLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["<pad>", "<s>", "</s>"]
- self.vocab += ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<pad>", "<s>", "</s>", "<unk>", "!"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = BloomTokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -32,10 +36,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[1, 4, 6, 7, 5, 8, 2, 0]],
+ "token_ids": [[3, 7, 19, 29, 28, 21, 1, 2]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
- [[4, 6, 7, 5, 8, 2, 0, 0]], # Pass through labels.
+ [[7, 19, 29, 28, 21, 1, 2, 2]], # Pass through labels.
[[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights.
),
)
@@ -49,21 +53,21 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[4, 6, 7, 5, 8, 0, 0, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[7, 19, 29, 28, 21, 2, 2, 2]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
- self.assertAllEqual(y, [[6, 7, 5, 8, 0, 0, 0, 0]] * 4)
+ self.assertAllEqual(y, [[19, 29, 28, 21, 2, 2, 2, 2]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = BloomCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- self.assertAllEqual(x["token_ids"], [1, 4, 6, 7, 5, 8, 0, 0])
+ self.assertAllEqual(x["token_ids"], [3, 7, 19, 29, 28, 21, 2, 2])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": [1, 4, 6, 7, 5, 8, 0, 0],
+ "token_ids": [3, 7, 19, 29, 28, 21, 2, 2],
"padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
}
preprocessor = BloomCausalLMPreprocessor(**self.init_kwargs)
diff --git a/keras_hub/src/models/bloom/bloom_causal_lm_test.py b/keras_hub/src/models/bloom/bloom_causal_lm_test.py
index 2814fb1b79..bab9a0655a 100644
--- a/keras_hub/src/models/bloom/bloom_causal_lm_test.py
+++ b/keras_hub/src/models/bloom/bloom_causal_lm_test.py
@@ -14,12 +14,16 @@
class BloomCausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["<unk>", "<s>", "</s>", "<pad>"]
- self.vocab += ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<pad>", "<s>", "</s>", "<unk>", "!"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = BloomTokenizer(
vocabulary=self.vocab, merges=self.merges
)
@@ -27,8 +31,9 @@ def setUp(self):
self.tokenizer,
sequence_length=8,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = BloomBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_heads=2,
hidden_dim=4,
@@ -47,12 +52,11 @@ def setUp(self):
self.input_data = self.preprocessor(*self.train_data)[0]
def test_causal_lm_basics(self):
- vocabulary_size = self.tokenizer.vocabulary_size()
self.run_task_test(
cls=BloomCausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 8, vocabulary_size),
+ expected_output_shape=(2, 8, self.vocabulary_size),
)
def test_generate(self):
diff --git a/keras_hub/src/models/bloom/bloom_tokenizer_test.py b/keras_hub/src/models/bloom/bloom_tokenizer_test.py
index 51b9a9087a..10e071d31c 100644
--- a/keras_hub/src/models/bloom/bloom_tokenizer_test.py
+++ b/keras_hub/src/models/bloom/bloom_tokenizer_test.py
@@ -6,12 +6,16 @@
class BloomTokenizerTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<s>", "</s>", "<pad>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<s>", "</s>", "<pad>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
self.input_data = [
"<s>airplane at airport<pad>",
@@ -23,14 +27,21 @@ def test_tokenizer_basics(self):
cls=BloomTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- expected_output=[[6, 1, 3, 4, 2, 5, 8], [6, 2, 3, 2, 5, 8]],
+ expected_output=[
+ [3, 6, 18, 28, 27, 20, 2],
+ [3, 27, 18, 27, 20, 2],
+ ],
+ expected_detokenize_output=[
+ "<s>airplane at airport<pad>",
+ "<s> airplane airport<pad>",
+ ],
)
def test_errors_missing_special_tokens(self):
with self.assertRaises(ValueError):
BloomTokenizer(vocabulary=["a", "b", "c"], merges=[])
- @pytest.mark.extra_large
+ @pytest.mark.large
def test_smallest_preset(self):
self.run_preset_test(
cls=BloomTokenizer,
diff --git a/keras_hub/src/models/causal_lm_preprocessor.py b/keras_hub/src/models/causal_lm_preprocessor.py
index 2bc1f7a3ce..e844074766 100644
--- a/keras_hub/src/models/causal_lm_preprocessor.py
+++ b/keras_hub/src/models/causal_lm_preprocessor.py
@@ -3,6 +3,7 @@
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_hub.src.models.preprocessor import Preprocessor
+from keras_hub.src.utils.tensor_utils import in_tf_function
from keras_hub.src.utils.tensor_utils import preprocessing_function
from keras_hub.src.utils.tensor_utils import strip_to_ragged
@@ -66,7 +67,10 @@ def __init__(
add_end_token=True,
**kwargs,
):
- super().__init__(**kwargs)
+ _allow_python_workflow = kwargs.pop("_allow_python_workflow", True)
+ super().__init__(
+ _allow_python_workflow=_allow_python_workflow, **kwargs
+ )
self.tokenizer = tokenizer
self.packer = None
self.sequence_length = sequence_length
@@ -85,14 +89,7 @@ def build(self, input_shape):
)
self.built = True
- @preprocessing_function
- def call(
- self,
- x,
- y=None,
- sample_weight=None,
- sequence_length=None,
- ):
+ def _call(self, x, y=None, sample_weight=None, sequence_length=None):
sequence_length = sequence_length or self.sequence_length
x = self.tokenizer(x)
# Pad with one extra token to account for the truncation below.
@@ -112,22 +109,28 @@ def call(
return keras.utils.pack_x_y_sample_weight(x, y, sample_weight)
@preprocessing_function
- def generate_preprocess(
- self,
- x,
- sequence_length=None,
- ):
- """Convert strings to integer token input for generation.
-
- Similar to calling the layer for training, this method takes in strings
- or tensor strings, tokenizes and packs the input, and computes a padding
- mask masking all inputs not filled in with a padded value.
+ def _call_tf(self, x, y=None, sample_weight=None, sequence_length=None):
+ return self._call(
+ x, y=y, sample_weight=sample_weight, sequence_length=sequence_length
+ )
- Unlike calling the layer for training, this method does not compute
- labels and will never append a `tokenizer.end_token_id` to the end of
- the sequence (as generation is expected to continue at the end of the
- inputted prompt).
- """
+ def call(self, x, y=None, sample_weight=None, sequence_length=None):
+ if in_tf_function():
+ return self._call_tf(
+ x,
+ y=y,
+ sample_weight=sample_weight,
+ sequence_length=sequence_length,
+ )
+ else:
+ return self._call(
+ x,
+ y=y,
+ sample_weight=sample_weight,
+ sequence_length=sequence_length,
+ )
+
+ def _generate_preprocess(self, x, sequence_length=None):
if not self.built:
self.build(None)
@@ -141,16 +144,54 @@ def generate_preprocess(
}
@preprocessing_function
- def generate_postprocess(
- self,
- x,
- ):
- """Convert integer token output to strings for generation.
+ def _generate_preprocess_tf(self, x, sequence_length=None):
+ return self._generate_preprocess(x, sequence_length=sequence_length)
- This method reverses `generate_preprocess()`, by first removing all
- padding and start/end tokens, and then converting the integer sequence
- back to a string.
+ def generate_preprocess(self, x, sequence_length=None):
+ """Convert strings to integer token input for generation.
+
+ Similar to calling the layer for training, this method takes in strings
+ or tensor strings, tokenizes and packs the input, and computes a padding
+ mask masking all inputs not filled in with a padded value.
+
+ Unlike calling the layer for training, this method does not compute
+ labels and will never append a `tokenizer.end_token_id` to the end of
+ the sequence (as generation is expected to continue at the end of the
+ inputted prompt).
"""
+ if in_tf_function():
+ return self._generate_preprocess_tf(
+ x, sequence_length=sequence_length
+ )
+ else:
+ return self._generate_preprocess(x, sequence_length=sequence_length)
+
+ def _generate_postprocess(self, x):
+ if not self.built:
+ self.build(None)
+
+ def _strip_to_ragged(token_ids, masks, ids_to_strip):
+ """Remove masked and special tokens from a sequence."""
+ for id in ids_to_strip:
+ masks = masks & (token_ids != id)
+ if token_ids.ndim == 1:
+ token_ids = token_ids[masks].tolist()
+ else:
+ ragged_ids = []
+ for i in range(token_ids.shape[0]):
+ ragged_ids.append(token_ids[i][masks[i]].tolist())
+ token_ids = ragged_ids
+ return token_ids
+
+ token_ids, padding_mask = x["token_ids"], x["padding_mask"]
+ ids_to_strip = self.tokenizer.special_token_ids
+ token_ids = keras.ops.convert_to_numpy(token_ids).astype("int32")
+ padding_mask = keras.ops.convert_to_numpy(padding_mask).astype("bool")
+ token_ids = _strip_to_ragged(token_ids, padding_mask, ids_to_strip)
+ return self.tokenizer.detokenize(token_ids)
+
+ @preprocessing_function
+ def _generate_postprocess_tf(self, x):
if not self.built:
self.build(None)
@@ -159,6 +200,18 @@ def generate_postprocess(
token_ids = strip_to_ragged(token_ids, padding_mask, ids_to_strip)
return self.tokenizer.detokenize(token_ids)
+ def generate_postprocess(self, x):
+ """Convert integer token output to strings for generation.
+
+ This method reverses `generate_preprocess()`, by first removing all
+ padding and start/end tokens, and then converting the integer sequence
+ back to a string.
+ """
+ if in_tf_function():
+ return self._generate_postprocess_tf(x)
+ else:
+ return self._generate_postprocess(x)
+
def get_config(self):
config = super().get_config()
config.update(
diff --git a/keras_hub/src/models/clip/clip_preprocessor_test.py b/keras_hub/src/models/clip/clip_preprocessor_test.py
index bec45a25d5..c763274d4d 100644
--- a/keras_hub/src/models/clip/clip_preprocessor_test.py
+++ b/keras_hub/src/models/clip/clip_preprocessor_test.py
@@ -9,11 +9,15 @@
class CLIPPreprocessorTest(TestCase):
def setUp(self):
- vocab = ["air", "plane", "port"]
- vocab += ["<|endoftext|>", "<|startoftext|>"]
- vocab = dict([(token, i + 1) for i, token in enumerate(vocab)])
merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
merges += ["po rt", "pla ne"]
+ vocab = []
+ for merge in merges:
+ a, b = merge.split(" ")
+ vocab.extend([a, b, a + b])
+ vocab += ["<|endoftext|>", "<|startoftext|>"]
+ vocab = sorted(set(vocab)) # Remove duplicates
+ vocab = dict([(token, i) for i, token in enumerate(vocab)])
self.tokenizer = CLIPTokenizer(vocabulary=vocab, merges=merges)
self.image_converter = CLIPImageConverter(
(224, 224),
@@ -37,7 +41,7 @@ def test_preprocessor_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output={
- "token_ids": [[5, 1, 2, 1, 3, 4, 0, 0]],
+ "token_ids": [[1, 4, 14, 4, 16, 0, 0, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
"images": np.ones([1, 224, 224, 3]) * -1.0,
},
@@ -71,7 +75,7 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[1, 2, 1, 3, 0, 0, 0, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[4, 14, 4, 16, 0, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
def test_sequence_length_override(self):
@@ -81,7 +85,7 @@ def test_sequence_length_override(self):
}
preprocessor = CLIPPreprocessor(**self.init_kwargs)
x = preprocessor(input_data, sequence_length=5)
- self.assertAllEqual(x["token_ids"], [5, 1, 2, 1, 4])
+ self.assertAllEqual(x["token_ids"], [1, 4, 14, 4, 0])
@pytest.mark.kaggle_key_required
@pytest.mark.extra_large
@@ -93,3 +97,34 @@ def test_all_presets(self):
preset=preset,
input_data=self.input_data,
)
+
+
+class CLIPPreprocessorDisallowPythonWorkflowTest(CLIPPreprocessorTest):
+ def setUp(self):
+ merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
+ merges += ["po rt", "pla ne"]
+ vocab = []
+ for merge in merges:
+ a, b = merge.split(" ")
+ vocab.extend([a, b, a + b])
+ vocab += ["<|endoftext|>", "<|startoftext|>"]
+ vocab = sorted(set(vocab)) # Remove duplicates
+ vocab = dict([(token, i) for i, token in enumerate(vocab)])
+ self.tokenizer = CLIPTokenizer(
+ vocabulary=vocab, merges=merges, _allow_python_workflow=False
+ )
+ self.image_converter = CLIPImageConverter(
+ (224, 224),
+ [2.0 / 255.0] * 3,
+ [-1.0] * 3,
+ interpolation="bicubic",
+ )
+ self.init_kwargs = {
+ "tokenizer": self.tokenizer,
+ "image_converter": self.image_converter,
+ "sequence_length": 8,
+ }
+ self.input_data = {
+ "prompts": [" airplane airport"],
+ "images": [np.zeros([512, 512, 3])],
+ }
diff --git a/keras_hub/src/models/clip/clip_tokenizer.py b/keras_hub/src/models/clip/clip_tokenizer.py
index 44e8832996..de00b09133 100644
--- a/keras_hub/src/models/clip/clip_tokenizer.py
+++ b/keras_hub/src/models/clip/clip_tokenizer.py
@@ -1,8 +1,15 @@
+import tokenizers
+from tokenizers import decoders
+from tokenizers import models
+from tokenizers import normalizers
+from tokenizers import pre_tokenizers
+from tokenizers import processors
+
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.models.clip.clip_backbone import CLIPBackbone
from keras_hub.src.tokenizers.byte_pair_tokenizer import BytePairTokenizer
-from keras_hub.src.tokenizers.byte_pair_tokenizer import convert_to_ragged_batch
from keras_hub.src.tokenizers.byte_pair_tokenizer import split_strings_for_bpe
+from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
from keras_hub.src.utils.tensor_utils import preprocessing_function
try:
@@ -79,20 +86,95 @@ def __init__(
**kwargs,
)
+ def _set_vocabulary_and_merges_tokenizers(self, vocabulary, merges):
+ # CLIPTokenizer has the extra settings.
+ # Ref: transformers.models.clip.tokenization_clip
+ vocabulary = self.vocabulary.copy()
+ merges = self.merges
+ _merges = []
+ for merge in merges:
+ if "#version:" in merge.lstrip():
+ continue
+ a, b = str(merge).split(" ")
+ if a not in vocabulary or b not in vocabulary:
+ raise ValueError(
+ f"Merge rule '{merge}' contains token '{a}' or '{b}' that "
+ "is not in the vocabulary."
+ )
+ _merges.append((a, b))
+ self._tokenizer = tokenizers.Tokenizer(
+ models.BPE(
+ vocab=vocabulary,
+ merges=_merges,
+ continuing_subword_prefix="",
+ end_of_word_suffix="",
+ fuse_unk=False,
+ unk_token="<|endoftext|>",
+ )
+ )
+ if self.unsplittable_tokens:
+ self._tokenizer.add_special_tokens(self.unsplittable_tokens)
+ self._tokenizer.normalizer = normalizers.Sequence(
+ [
+ normalizers.NFC(),
+ normalizers.Replace(tokenizers.Regex(r"\s+"), " "),
+ normalizers.Lowercase(),
+ ]
+ )
+ self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+ [
+ pre_tokenizers.Split(
+ tokenizers.Regex(
+ r"""<\|startoftext\|>|<\|endoftext\|>|'s|'t|'re|'ve|'m|'ll|'d|[\p{L}]+|[\p{N}]|[^\s\p{L}\p{N}]+"""
+ ),
+ behavior="removed",
+ invert=True,
+ ),
+ pre_tokenizers.ByteLevel(
+ add_prefix_space=self.add_prefix_space
+ ),
+ ]
+ )
+ self._tokenizer.decoder = decoders.ByteLevel()
+
+ # Dummy attrs for serialization compatibility.
+ if not hasattr(self, "cache"):
+ self.byte2unicode = None
+ self.unicode2byte = None
+ self.cache = None
+ self.id_to_token_map = None
+ self.token_to_id_map = None
+ self.merge_ranks_lookup_default = None
+ self.merge_ranks = None
+
def set_vocabulary_and_merges(self, vocabulary, merges):
super().set_vocabulary_and_merges(vocabulary, merges)
if self.pad_with_end_token:
self.pad_token_id = self.end_token_id
+ if getattr(self, "_tokenizer") is not None:
+ self._tokenizer.post_processor = processors.RobertaProcessing(
+ sep=(str(self.end_token), self.end_token_id),
+ cls=(str(self.start_token), self.start_token_id),
+ add_prefix_space=False,
+ trim_offsets=False,
+ )
- def _bpe_merge_and_update_cache(self, tokens):
+ def _bpe_merge_and_update_cache_tf(self, tokens):
"""Process unseen tokens and add to cache."""
- words = self._transform_bytes(tokens)
+
+ def _transform_bytes(tokens):
+ """Map token bytes to unicode using `byte2unicode`."""
+ split_bytes = tf.strings.bytes_split(tokens)
+ split_unicode = self.byte2unicode.lookup(split_bytes)
+ return split_unicode
+
+ words = _transform_bytes(tokens)
# In CLIP, we need to add `</w>` to the last word.
words = tf.strings.reduce_join(words, axis=1, separator=" ")
words = tf.strings.join([words, "</w>"])
words = tf.strings.split(words, sep=" ")
- tokenized_words = self._bpe_merge(words)
+ tokenized_words = self._bpe_merge_tf(words)
# For each word, join all its token by a whitespace,
# e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
@@ -102,11 +184,12 @@ def _bpe_merge_and_update_cache(self, tokens):
self.cache.insert(tokens, tokenized_words)
@preprocessing_function
- def tokenize(self, inputs):
- self._check_vocabulary()
+ def _tokenize_tf(self, inputs):
+ self._maybe_initialized_tf()
if self.add_prefix_space:
inputs = tf.strings.join([" ", inputs])
+ inputs = tf.convert_to_tensor(inputs)
unbatched = inputs.shape.rank == 0
if unbatched:
inputs = tf.expand_dims(inputs, 0)
@@ -121,21 +204,19 @@ def tokenize(self, inputs):
# Strip and remove empty tokens.
raw_tokens = tf.strings.strip(raw_tokens)
raw_tokens = tf.ragged.boolean_mask(raw_tokens, raw_tokens != "")
-
token_row_splits = raw_tokens.row_splits
flat_tokens = raw_tokens.flat_values
# Check cache.
cache_lookup = self.cache.lookup(flat_tokens)
cache_mask = cache_lookup == ""
-
has_unseen_words = tf.math.reduce_any(
(cache_lookup == "") & (flat_tokens != "")
)
def process_unseen_tokens():
unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
- self._bpe_merge_and_update_cache(unseen_tokens)
+ self._bpe_merge_and_update_cache_tf(unseen_tokens)
return self.cache.lookup(flat_tokens)
# If `has_unseen_words == True`, it means not all tokens are in cache,
@@ -145,7 +226,6 @@ def process_unseen_tokens():
process_unseen_tokens,
lambda: cache_lookup,
)
-
tokens = tf.strings.split(tokenized_words, sep=" ")
if self.compute_dtype != tf.string:
# Encode merged tokens.
@@ -167,11 +247,24 @@ def process_unseen_tokens():
if unbatched:
tokens = tf.squeeze(tokens, 0)
tf.ensure_shape(tokens, shape=[self.sequence_length])
-
return tokens
+ def _tokenize_tokenizers(self, inputs):
+ outputs = super()._tokenize_tokenizers(inputs)
+ is_batched = True
+ if isinstance(outputs, str):
+ is_batched = False
+ outputs = [outputs]
+ elif isinstance(outputs, list) and isinstance(outputs[0], int):
+ is_batched = False
+ outputs = [outputs]
+ outputs = [output[1:-1] for output in outputs]
+ if not is_batched:
+ outputs = outputs[0]
+ return outputs
+
@preprocessing_function
- def detokenize(self, inputs):
+ def _detokenize_tf(self, inputs):
self._check_vocabulary()
inputs, unbatched, _ = convert_to_ragged_batch(inputs)
inputs = tf.cast(inputs, self.dtype)
@@ -192,6 +285,24 @@ def detokenize(self, inputs):
outputs = tf.squeeze(outputs, 0)
return outputs
+ def _detokenize_tokenizers(self, inputs):
+ outputs = super()._detokenize_tokenizers(inputs)
+
+ def _remove_special_token(inputs):
+ is_batched = True
+ if isinstance(inputs, str):
+ inputs = [inputs]
+ is_batched = False
+ for special_token in ("</w>", self.start_token, self.end_token):
+ inputs = [
+ input.replace(str(special_token), "") for input in inputs
+ ]
+ if not is_batched:
+ inputs = inputs[0]
+ return inputs
+
+ return _remove_special_token(outputs)
+
def get_config(self):
config = super().get_config()
config.update(
diff --git a/keras_hub/src/models/clip/clip_tokenizer_test.py b/keras_hub/src/models/clip/clip_tokenizer_test.py
index bb707dad7e..ee50c55608 100644
--- a/keras_hub/src/models/clip/clip_tokenizer_test.py
+++ b/keras_hub/src/models/clip/clip_tokenizer_test.py
@@ -6,12 +6,16 @@
class CLIPTokenizerTest(TestCase):
def setUp(self):
- vocab = ["air", "plane", "port"]
- vocab += ["<|endoftext|>", "<|startoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(vocab)])
merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
merges += ["po rt", "pla ne"]
self.merges = merges
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<|endoftext|>", "<|startoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
self.input_data = ["airplane ", " airport"]
@@ -21,7 +25,7 @@ def test_tokenizer_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
# Whitespaces should be removed.
- expected_output=[[0, 1], [0, 2]],
+ expected_output=[[4, 14], [4, 16]],
expected_detokenize_output=["airplane", "airport"],
)
@@ -52,3 +56,9 @@ def test_all_presets(self):
preset=preset,
input_data=self.input_data,
)
+
+
+class CLIPTokenizerDisallowPythonWorkflowTest(CLIPTokenizerTest):
+ def setUp(self):
+ super().setUp()
+ self.init_kwargs.update({"_allow_python_workflow": False})
diff --git a/keras_hub/src/models/falcon/falcon_causal_lm_preprocessor_test.py b/keras_hub/src/models/falcon/falcon_causal_lm_preprocessor_test.py
index 42fd3e1206..20d8b0ad72 100644
--- a/keras_hub/src/models/falcon/falcon_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/falcon/falcon_causal_lm_preprocessor_test.py
@@ -9,12 +9,16 @@
class FalconCausalLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = FalconTokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -32,10 +36,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]],
+ "token_ids": [[1, 4, 16, 26, 25, 18, 1, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
- [[1, 3, 4, 2, 5, 6, 0, 0]], # Pass through labels.
+ [[4, 16, 26, 25, 18, 1, 0, 0]], # Pass through labels.
[[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights.
),
)
@@ -49,22 +53,22 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[4, 16, 26, 25, 18, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
- self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0]] * 4)
+ self.assertAllEqual(y, [[16, 26, 25, 18, 0, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
+ self.assertAllEqual(x["token_ids"], [1, 4, 16, 26, 25, 18, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": [6, 1, 3, 4, 2, 5, 0, 0],
- "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
+ "token_ids": [1, 4, 16, 26, 25, 18, 1, 0],
+ "padding_mask": [1, 1, 1, 1, 1, 1, 1, 0],
}
preprocessor = FalconCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_postprocess(input_data)
diff --git a/keras_hub/src/models/falcon/falcon_causal_lm_test.py b/keras_hub/src/models/falcon/falcon_causal_lm_test.py
index b8b5c9c026..3e0344bf27 100644
--- a/keras_hub/src/models/falcon/falcon_causal_lm_test.py
+++ b/keras_hub/src/models/falcon/falcon_causal_lm_test.py
@@ -14,12 +14,16 @@
class FalconCausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = FalconTokenizer(
vocabulary=self.vocab, merges=self.merges
)
@@ -27,8 +31,9 @@ def setUp(self):
self.tokenizer,
sequence_length=8,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = FalconBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_attention_heads=2,
hidden_dim=4,
@@ -47,12 +52,11 @@ def setUp(self):
self.input_data = self.preprocessor(*self.train_data)[0]
def test_causal_lm_basics(self):
- vocabulary_size = self.tokenizer.vocabulary_size()
self.run_task_test(
cls=FalconCausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 8, vocabulary_size),
+ expected_output_shape=(2, 8, self.vocabulary_size),
)
def test_generate(self):
diff --git a/keras_hub/src/models/falcon/falcon_tokenizer_test.py b/keras_hub/src/models/falcon/falcon_tokenizer_test.py
index 2c10284af8..02645a45ae 100644
--- a/keras_hub/src/models/falcon/falcon_tokenizer_test.py
+++ b/keras_hub/src/models/falcon/falcon_tokenizer_test.py
@@ -6,12 +6,16 @@
class FalconTokenizerTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
self.input_data = [
" airplane at airport<|endoftext|>",
@@ -23,7 +27,14 @@ def test_tokenizer_basics(self):
cls=FalconTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- expected_output=[[2, 3, 4, 2, 5, 6], [2, 3, 2, 5]],
+ expected_output=[
+ [25, 16, 26, 25, 18, 1],
+ [25, 16, 25, 18],
+ ],
+ expected_detokenize_output=[
+ " airplane at airport<|endoftext|>",
+ " airplane airport",
+ ],
)
def test_errors_missing_special_tokens(self):
diff --git a/keras_hub/src/models/flux/flux_text_to_image_preprocessor_test.py b/keras_hub/src/models/flux/flux_text_to_image_preprocessor_test.py
index d9a3a9d0a8..8af2964633 100644
--- a/keras_hub/src/models/flux/flux_text_to_image_preprocessor_test.py
+++ b/keras_hub/src/models/flux/flux_text_to_image_preprocessor_test.py
@@ -10,11 +10,15 @@
class FluxTextToImagePreprocessorTest(TestCase):
def setUp(self):
- vocab = ["air", "plane", "port"]
- vocab += ["<|endoftext|>", "<|startoftext|>"]
- vocab = dict([(token, i) for i, token in enumerate(vocab)])
merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
merges += ["po rt", "pla ne"]
+ vocab = []
+ for merge in merges:
+ a, b = merge.split(" ")
+ vocab.extend([a, b, a + b])
+ vocab += ["<|endoftext|>", "<|startoftext|>"]
+ vocab = sorted(set(vocab)) # Remove duplicates
+ vocab = dict([(token, i) for i, token in enumerate(vocab)])
clip_l_tokenizer = CLIPTokenizer(
vocabulary=vocab, merges=merges, pad_with_end_token=True
)
@@ -48,4 +52,4 @@ def test_generate_preprocess(self):
preprocessor = FluxTextToImagePreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(self.input_data)
self.assertIn("clip_l", x)
- self.assertAllEqual(x["clip_l"][0], [4, 0, 1, 3, 3, 3, 3, 3])
+ self.assertAllEqual(x["clip_l"][0], [1, 4, 14, 0, 0, 0, 0, 0])
diff --git a/keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor_test.py b/keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor_test.py
index 9b0a159356..c63c425b04 100644
--- a/keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/gpt2/gpt2_causal_lm_preprocessor_test.py
@@ -9,12 +9,16 @@
class GPT2CausalLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = GPT2Tokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -32,10 +36,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]],
+ "token_ids": [[1, 4, 16, 26, 25, 18, 1, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
- [[1, 3, 4, 2, 5, 6, 0, 0]], # Pass through labels.
+ [[4, 16, 26, 25, 18, 1, 0, 0]], # Pass through labels.
[[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights.
),
)
@@ -49,22 +53,22 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[4, 16, 26, 25, 18, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
- self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0]] * 4)
+ self.assertAllEqual(y, [[16, 26, 25, 18, 0, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = GPT2CausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
+ self.assertAllEqual(x["token_ids"], [1, 4, 16, 26, 25, 18, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": [6, 1, 3, 4, 2, 5, 0, 0],
- "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
+ "token_ids": [1, 4, 16, 26, 25, 18, 1, 0],
+ "padding_mask": [1, 1, 1, 1, 1, 1, 1, 0],
}
preprocessor = GPT2CausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_postprocess(input_data)
diff --git a/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py b/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py
index 91be917c01..23af557f07 100644
--- a/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py
+++ b/keras_hub/src/models/gpt2/gpt2_causal_lm_test.py
@@ -15,18 +15,23 @@
class GPT2CausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = GPT2CausalLMPreprocessor(
GPT2Tokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=8,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = GPT2Backbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_heads=2,
hidden_dim=4,
@@ -45,7 +50,7 @@ def test_causal_lm_basics(self):
cls=GPT2CausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 8, 7),
+ expected_output_shape=(2, 8, self.vocabulary_size),
)
def test_generate(self):
@@ -145,7 +150,7 @@ def test_score_logits(self):
# Setup prompts, models, and associated expected shapes.
prompts = [" airplane at airport", " airplane at airport"]
causal_lm = GPT2CausalLM(**self.init_kwargs)
- expected_score_shape = (2, 8, 7)
+ expected_score_shape = (2, 8, self.vocabulary_size)
# Preprocess prompts to get tokenized representations and padding masks.
preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
@@ -192,7 +197,7 @@ def test_score_layer_intercept_fn_exfiltration(self):
prompts = [" airplane at airport", " airplane at airport"]
causal_lm = GPT2CausalLM(**self.init_kwargs)
expected_embedded_shape = (2, 8, 4)
- expected_score_shape = (2, 8, 7)
+ expected_score_shape = (2, 8, self.vocabulary_size)
# Preprocess prompts to get tokenized representations and padding masks.
preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
diff --git a/keras_hub/src/models/gpt2/gpt2_preprocessor_test.py b/keras_hub/src/models/gpt2/gpt2_preprocessor_test.py
index a116e04374..0c6df3a650 100644
--- a/keras_hub/src/models/gpt2/gpt2_preprocessor_test.py
+++ b/keras_hub/src/models/gpt2/gpt2_preprocessor_test.py
@@ -7,12 +7,16 @@
class GPT2PreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = GPT2Tokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -29,7 +33,7 @@ def test_preprocessor_basics(self):
init_kwargs=self.init_kwargs,
input_data=self.input_data,
expected_output={
- "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]],
+ "token_ids": [[1, 4, 16, 26, 25, 18, 1, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
)
@@ -47,14 +51,14 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[4, 16, 26, 25, 18, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
def test_sequence_length_override(self):
input_data = "airplane at airport"
preprocessor = GPT2Preprocessor(**self.init_kwargs)
x = preprocessor(input_data, sequence_length=4)
- self.assertAllEqual(x["token_ids"], [6, 1, 3, 6])
+ self.assertAllEqual(x["token_ids"], [1, 4, 16, 1])
@pytest.mark.extra_large
def test_all_presets(self):
diff --git a/keras_hub/src/models/gpt2/gpt2_tokenizer_test.py b/keras_hub/src/models/gpt2/gpt2_tokenizer_test.py
index 0b887cf914..6b3dd2017a 100644
--- a/keras_hub/src/models/gpt2/gpt2_tokenizer_test.py
+++ b/keras_hub/src/models/gpt2/gpt2_tokenizer_test.py
@@ -6,12 +6,16 @@
class GPT2TokenizerTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
self.input_data = [
" airplane at airport<|endoftext|>",
@@ -23,7 +27,14 @@ def test_tokenizer_basics(self):
cls=GPT2Tokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- expected_output=[[2, 3, 4, 2, 5, 6], [2, 3, 2, 5]],
+ expected_output=[
+ [23, 14, 24, 23, 16, 30],
+ [23, 14, 23, 16],
+ ],
+ expected_detokenize_output=[
+ " airplane at airport<|endoftext|>",
+ " airplane airport",
+ ],
)
def test_errors_missing_special_tokens(self):
diff --git a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py
index 769dc50260..91f9e685ba 100644
--- a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_preprocessor_test.py
@@ -1,5 +1,3 @@
-from keras import ops
-
from keras_hub.src.models.gpt_neo_x.gpt_neo_x_causal_lm_preprocessor import (
GPTNeoXCausalLMPreprocessor,
)
@@ -9,12 +7,16 @@
class GPTNeoXCausalLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = GPTNeoXTokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -32,10 +34,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[6, 1, 3, 4, 2, 5, 6, 0]],
+ "token_ids": [[1, 4, 16, 26, 25, 18, 1, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
- [[1, 3, 4, 2, 5, 6, 0, 0]], # Pass through labels.
+ [[4, 16, 26, 25, 18, 1, 0, 0]], # Pass through labels.
[[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights.
),
)
@@ -49,22 +51,22 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[4, 16, 26, 25, 18, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
- self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0]] * 4)
+ self.assertAllEqual(y, [[16, 26, 25, 18, 0, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = GPTNeoXCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
+ self.assertAllEqual(x["token_ids"], [1, 4, 16, 26, 25, 18, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": ops.array([6, 1, 3, 4, 2, 5, 0, 0]),
- "padding_mask": ops.array([1, 1, 1, 1, 1, 1, 0, 0], dtype="bool"),
+ "token_ids": [1, 4, 16, 26, 25, 18, 1, 0],
+ "padding_mask": [1, 1, 1, 1, 1, 1, 1, 0],
}
preprocessor = GPTNeoXCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_postprocess(input_data)
diff --git a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py
index 45d6214680..3ef3d4073b 100644
--- a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py
+++ b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_causal_lm_test.py
@@ -14,18 +14,23 @@
class GPTNeoXCausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = GPTNeoXCausalLMPreprocessor(
GPTNeoXTokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=8,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = GPTNeoXBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_heads=2,
hidden_dim=4,
@@ -44,7 +49,7 @@ def test_causal_lm_basics(self):
cls=GPTNeoXCausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 8, 7),
+ expected_output_shape=(2, 8, self.vocabulary_size),
)
def test_generate(self):
diff --git a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py
index 18ae370cf4..de3876602d 100644
--- a/keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py
+++ b/keras_hub/src/models/gpt_neo_x/gpt_neo_x_tokenizer_test.py
@@ -4,12 +4,16 @@
class GPTNeoXTokenizerTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
self.input_data = [
" airplane at airport<|endoftext|>",
@@ -21,7 +25,14 @@ def test_tokenizer_basics(self):
cls=GPTNeoXTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- expected_output=[[2, 3, 4, 2, 5, 6], [2, 3, 2, 5]],
+ expected_output=[
+ [25, 16, 26, 25, 18, 1],
+ [25, 16, 25, 18],
+ ],
+ expected_detokenize_output=[
+ " airplane at airport<|endoftext|>",
+ " airplane airport",
+ ],
)
def test_errors_missing_special_tokens(self):
diff --git a/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor_test.py b/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor_test.py
index 2f3c9e9db7..bc894d7b38 100644
--- a/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_preprocessor_test.py
@@ -10,12 +10,16 @@
class GptOssCausalLMPreprocessorTest(TestCase):
def setUp(self):
# Define vocabulary and merges inline like GPT-2 tests
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|startoftext|>", "<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab += ["!", "<|startoftext|>", "<|endoftext|>"]
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = GptOssTokenizer(
vocabulary=self.vocab, merges=self.merges
)
@@ -32,10 +36,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[1, 3, 4, 2, 5, 7, 0, 0]],
+ "token_ids": [[2, 14, 24, 23, 16, 31, 0, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
},
- [[3, 4, 2, 5, 7, 0, 0, 0]], # Pass through labels.
+ [[14, 24, 23, 16, 31, 0, 0, 0]], # Pass through labels.
[[1, 1, 1, 1, 1, 0, 0, 0]], # Pass through sample_weights.
),
)
@@ -49,24 +53,22 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
- # `[3, 8, 4, 6]` -> ` the quick brown fox`
- self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 0, 0, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[2, 14, 24, 23, 16, 0, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
- self.assertAllEqual(y, [[3, 4, 2, 5, 0, 0, 0, 0]] * 4)
+ self.assertAllEqual(y, [[14, 24, 23, 16, 0, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = GptOssCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- # `[1, 3, 8, 4, 6]` -> ` the quick brown fox`
# `generate_preprocess` should not add an end token.
- self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 0, 0, 0])
+ self.assertAllEqual(x["token_ids"], [2, 14, 24, 23, 16, 0, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": [1, 3, 4, 2, 5, 7, 7, 7],
+ "token_ids": [2, 14, 24, 23, 16, 0, 0, 0],
"padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
}
preprocessor = GptOssCausalLMPreprocessor(**self.init_kwargs)
diff --git a/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_test.py b/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_test.py
index 3968af58d3..50ff726608 100644
--- a/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_test.py
+++ b/keras_hub/src/models/gpt_oss/gpt_oss_causal_lm_test.py
@@ -15,18 +15,23 @@
class GptOssCausalLMTest(TestCase):
def setUp(self):
# Define vocabulary and merges inline like GPT-2 tests
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|startoftext|>", "<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<|endoftext|>", "!"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = GptOssCausalLMPreprocessor(
GptOssTokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=8,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = GptOssBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_query_heads=4,
num_key_value_heads=2,
@@ -46,7 +51,7 @@ def test_causal_lm_basics(self):
cls=GptOssCausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 8, 8),
+ expected_output_shape=(2, 8, self.vocabulary_size),
)
def test_generate(self):
@@ -128,7 +133,7 @@ def test_score_logits(self):
# Setup prompts, models, and associated expected shapes.
prompts = [" airplane at airport", " airplane"]
causal_lm = GptOssCausalLM(**self.init_kwargs)
- expected_score_shape = (2, 8, 8)
+ expected_score_shape = (2, 8, self.vocabulary_size)
# Preprocess prompts to get tokenized representations and padding masks.
preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
@@ -175,7 +180,7 @@ def test_score_layer_intercept_fn_exfiltration(self):
prompts = [" airplane at airport", " airplane"]
causal_lm = GptOssCausalLM(**self.init_kwargs)
expected_embedded_shape = (2, 8, 8)
- expected_score_shape = (2, 8, 8)
+ expected_score_shape = (2, 8, self.vocabulary_size)
# Preprocess prompts to get tokenized representations and padding masks.
preprocessed_prompts = causal_lm.preprocessor.generate_preprocess(
diff --git a/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py b/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py
index f79be674fb..fb5bcbc5df 100644
--- a/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/llama3/llama3_causal_lm_preprocessor_test.py
@@ -9,14 +9,18 @@
class Llama3CausalLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
- self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
- self.vocab += ["<|eot_id|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|end_of_text|>", "<|begin_of_text|>"]
+ self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+ self.vocab += ["<|eot_id|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = Llama3Tokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -34,10 +38,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[6, 1, 3, 4, 2, 5, 7, 0]],
+ "token_ids": [[1, 8, 20, 30, 29, 22, 3, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
- [[1, 3, 4, 2, 5, 7, 0, 0]],
+ [[8, 20, 30, 29, 22, 3, 0, 0]],
[[1, 1, 1, 1, 1, 1, 0, 0]],
),
)
@@ -51,21 +55,21 @@ def test_with_start_end_token(self):
add_end_token=True,
)
x, y, sw = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[6, 1, 3, 4, 2, 5, 7, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[1, 8, 20, 30, 29, 22, 3, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 1, 0]] * 4)
- self.assertAllEqual(y, [[1, 3, 4, 2, 5, 7, 0, 0]] * 4)
+ self.assertAllEqual(y, [[8, 20, 30, 29, 22, 3, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = Llama3CausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- self.assertAllEqual(x["token_ids"], [6, 1, 3, 4, 2, 5, 0, 0])
+ self.assertAllEqual(x["token_ids"], [1, 8, 20, 30, 29, 22, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": [6, 1, 3, 4, 2, 5, 0, 0],
+ "token_ids": [1, 8, 20, 30, 29, 22, 0, 0],
"padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
}
preprocessor = Llama3CausalLMPreprocessor(**self.init_kwargs)
diff --git a/keras_hub/src/models/llama3/llama3_causal_lm_test.py b/keras_hub/src/models/llama3/llama3_causal_lm_test.py
index 47ff516103..8208752003 100644
--- a/keras_hub/src/models/llama3/llama3_causal_lm_test.py
+++ b/keras_hub/src/models/llama3/llama3_causal_lm_test.py
@@ -14,20 +14,25 @@
class Llama3CausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|begin_of_text|>", "<|end_of_text|>"]
- self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
- self.vocab += ["<|eot_id|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|end_of_text|>", "<|begin_of_text|>"]
+ self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
+ self.vocab += ["<|eot_id|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = Llama3CausalLMPreprocessor(
Llama3Tokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=7,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = Llama3Backbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_query_heads=4,
num_key_value_heads=2,
@@ -46,7 +51,7 @@ def test_causal_lm_basics(self):
cls=Llama3CausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 7, 11),
+ expected_output_shape=(2, 7, self.vocabulary_size),
)
def test_generate(self):
diff --git a/keras_hub/src/models/llama3/llama3_tokenizer_test.py b/keras_hub/src/models/llama3/llama3_tokenizer_test.py
index a6b50530ba..6387bbc15b 100644
--- a/keras_hub/src/models/llama3/llama3_tokenizer_test.py
+++ b/keras_hub/src/models/llama3/llama3_tokenizer_test.py
@@ -6,14 +6,17 @@
class Llama3TokenizerTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|end_of_text|>", "<|begin_of_text|>"]
- self.vocab += ["<|start_header_id|>", "<|end_header_id|>"]
- self.vocab += ["<|eot_id|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|end_of_text|>", "<|begin_of_text|>"]
+ self.vocab += ["<|start_header_id|>", "<|end_header_id|>", "<|eot_id|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
self.input_data = [
"<|begin_of_text|>airplane at airport<|end_of_text|>",
@@ -25,7 +28,14 @@ def test_tokenizer_basics(self):
cls=Llama3Tokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- expected_output=[[7, 1, 3, 4, 2, 5, 6], [2, 3, 2, 5]],
+ expected_output=[
+ [1, 8, 20, 30, 29, 22, 3],
+ [29, 20, 29, 22],
+ ],
+ expected_detokenize_output=[
+ "<|begin_of_text|>airplane at airport<|end_of_text|>",
+ " airplane airport",
+ ],
)
def test_errors_missing_special_tokens(self):
diff --git a/keras_hub/src/models/masked_lm_preprocessor.py b/keras_hub/src/models/masked_lm_preprocessor.py
index d09b3a200e..f5bba40b04 100644
--- a/keras_hub/src/models/masked_lm_preprocessor.py
+++ b/keras_hub/src/models/masked_lm_preprocessor.py
@@ -79,6 +79,10 @@ def __init__(
self.random_token_rate = random_token_rate
self.masker = None
+ # TODO(hongyu): Since `MultiSegmentPacker` requires TF workflow, we
+ # currently disable the Python workflow for `MaskedLMPreprocessor`.
+ self.tokenizer._allow_python_workflow = False
+
def build(self, input_shape):
super().build(input_shape)
# Defer masker creation to `build()` so that we can be sure tokenizer
diff --git a/keras_hub/src/models/opt/opt_causal_lm_preprocessor_test.py b/keras_hub/src/models/opt/opt_causal_lm_preprocessor_test.py
index bfbf5b2640..7f1336c938 100644
--- a/keras_hub/src/models/opt/opt_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/opt/opt_causal_lm_preprocessor_test.py
@@ -9,11 +9,16 @@
class OPTCausalLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["", "", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["", ""]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = OPTTokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -31,10 +36,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[1, 2, 4, 5, 3, 6, 1, 0]],
+ "token_ids": [[0, 4, 16, 26, 25, 18, 0, 1]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
- [[2, 4, 5, 3, 6, 1, 0, 0]], # Pass through labels.
+ [[4, 16, 26, 25, 18, 0, 1, 1]], # Pass through labels.
[[1, 1, 1, 1, 1, 1, 0, 0]], # Pass through sample_weights.
),
)
@@ -48,22 +53,22 @@ def test_no_start_end_token(self):
add_end_token=False,
)
x, y, sw = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[2, 4, 5, 3, 6, 0, 0, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[4, 16, 26, 25, 18, 1, 1, 1]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
- self.assertAllEqual(y, [[4, 5, 3, 6, 0, 0, 0, 0]] * 4)
+ self.assertAllEqual(y, [[16, 26, 25, 18, 1, 1, 1, 1]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 0, 0, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = OPTCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- self.assertAllEqual(x["token_ids"], [1, 2, 4, 5, 3, 6, 0, 0])
+ self.assertAllEqual(x["token_ids"], [0, 4, 16, 26, 25, 18, 1, 1])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 1, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": [1, 2, 4, 5, 3, 6, 0, 0],
- "padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
+ "token_ids": [0, 4, 16, 26, 25, 18, 0, 1],
+ "padding_mask": [1, 1, 1, 1, 1, 1, 1, 0],
}
preprocessor = OPTCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_postprocess(input_data)
diff --git a/keras_hub/src/models/opt/opt_causal_lm_test.py b/keras_hub/src/models/opt/opt_causal_lm_test.py
index 576e777a94..f024c8b1b2 100644
--- a/keras_hub/src/models/opt/opt_causal_lm_test.py
+++ b/keras_hub/src/models/opt/opt_causal_lm_test.py
@@ -14,17 +14,23 @@
class OPTCausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["", "", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["", ""]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = OPTCausalLMPreprocessor(
OPTTokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=8,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = OPTBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_heads=2,
hidden_dim=4,
@@ -43,7 +49,7 @@ def test_causal_lm_basics(self):
cls=OPTCausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 8, 7),
+ expected_output_shape=(2, 8, self.vocabulary_size),
)
def test_generate(self):
diff --git a/keras_hub/src/models/opt/opt_tokenizer_test.py b/keras_hub/src/models/opt/opt_tokenizer_test.py
index 357516fd13..92e611b1ff 100644
--- a/keras_hub/src/models/opt/opt_tokenizer_test.py
+++ b/keras_hub/src/models/opt/opt_tokenizer_test.py
@@ -6,11 +6,16 @@
class OPTTokenizerTest(TestCase):
def setUp(self):
- self.vocab = ["", "", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab += ["", ""]
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
self.input_data = [
" airplane at airport",
@@ -22,7 +27,14 @@ def test_tokenizer_basics(self):
cls=OPTTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- expected_output=[[3, 4, 5, 3, 6, 1], [3, 4, 3, 6]],
+ expected_output=[
+ [23, 14, 24, 23, 16, 30],
+ [23, 14, 23, 16],
+ ],
+ expected_detokenize_output=[
+ " airplane at airport",
+ " airplane airport",
+ ],
)
def test_errors_missing_special_tokens(self):
diff --git a/keras_hub/src/models/qwen/qwen_causal_lm_test.py b/keras_hub/src/models/qwen/qwen_causal_lm_test.py
index 7cddd4a714..aa53157db8 100644
--- a/keras_hub/src/models/qwen/qwen_causal_lm_test.py
+++ b/keras_hub/src/models/qwen/qwen_causal_lm_test.py
@@ -14,19 +14,23 @@
class QwenCausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab += ["<|eot_id|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["!", "<|endoftext|>", "<|eot_id|>"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = QwenCausalLMPreprocessor(
QwenTokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=7,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = QwenBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_query_heads=4,
num_key_value_heads=2,
@@ -45,7 +49,7 @@ def test_causal_lm_basics(self):
cls=QwenCausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 7, 8),
+ expected_output_shape=(2, 7, self.vocabulary_size),
)
def test_generate(self):
diff --git a/keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor_test.py b/keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor_test.py
index abbbc9a2bc..a03d10445c 100644
--- a/keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/qwen3/qwen3_causal_lm_preprocessor_test.py
@@ -7,12 +7,16 @@
class Qwen3CausalLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|im_end|>", "<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<|im_end|>", "<|endoftext|>", "!"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = Qwen3Tokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -30,10 +34,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[1, 3, 4, 2, 5, 6, 7, 7]],
+ "token_ids": [[5, 17, 27, 26, 19, 2, 1, 1]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
},
- [[3, 4, 2, 5, 6, 7, 7, 7]],
+ [[17, 27, 26, 19, 2, 1, 1, 1]],
[[1, 1, 1, 1, 1, 0, 0, 0]],
),
)
@@ -46,21 +50,21 @@ def test_with_start_end_token(self):
add_end_token=True,
)
x, y, sw = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 6, 7, 7]] * 4)
+ self.assertAllEqual(x["token_ids"], [[5, 17, 27, 26, 19, 2, 1, 1]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)
- self.assertAllEqual(y, [[3, 4, 2, 5, 6, 7, 7, 7]] * 4)
+ self.assertAllEqual(y, [[17, 27, 26, 19, 2, 1, 1, 1]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = Qwen3CausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 7, 7, 7])
+ self.assertAllEqual(x["token_ids"], [5, 17, 27, 26, 19, 1, 1, 1])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": [1, 3, 4, 2, 5, 7, 7, 7],
+ "token_ids": [5, 17, 27, 26, 19, 1, 1, 1],
"padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
}
preprocessor = Qwen3CausalLMPreprocessor(**self.init_kwargs)
diff --git a/keras_hub/src/models/qwen3/qwen3_causal_lm_test.py b/keras_hub/src/models/qwen3/qwen3_causal_lm_test.py
index d7d8758507..3805671a3b 100644
--- a/keras_hub/src/models/qwen3/qwen3_causal_lm_test.py
+++ b/keras_hub/src/models/qwen3/qwen3_causal_lm_test.py
@@ -14,19 +14,23 @@
class Qwen3CausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab += ["<|im_end|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<|im_end|>", "<|endoftext|>", "!"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = Qwen3CausalLMPreprocessor(
Qwen3Tokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=7,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = Qwen3Backbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_query_heads=4,
num_key_value_heads=2,
@@ -46,7 +50,7 @@ def test_causal_lm_basics(self):
cls=Qwen3CausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 7, 8),
+ expected_output_shape=(2, 7, self.vocabulary_size),
)
def test_generate(self):
diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor_test.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor_test.py
index 180c5f64ed..8a48476ca0 100644
--- a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_preprocessor_test.py
@@ -7,12 +7,16 @@
class Qwen3MoeCausalLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|im_end|>", "<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<|im_end|>", "<|endoftext|>", "!"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = Qwen3MoeTokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -30,10 +34,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[1, 3, 4, 2, 5, 6, 7, 7]],
+ "token_ids": [[5, 17, 27, 26, 19, 2, 1, 1]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
},
- [[3, 4, 2, 5, 6, 7, 7, 7]],
+ [[17, 27, 26, 19, 2, 1, 1, 1]],
[[1, 1, 1, 1, 1, 0, 0, 0]],
),
)
@@ -46,21 +50,21 @@ def test_with_start_end_token(self):
add_end_token=True,
)
x, y, sw = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 6, 7, 7]] * 4)
+ self.assertAllEqual(x["token_ids"], [[5, 17, 27, 26, 19, 2, 1, 1]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)
- self.assertAllEqual(y, [[3, 4, 2, 5, 6, 7, 7, 7]] * 4)
+ self.assertAllEqual(y, [[17, 27, 26, 19, 2, 1, 1, 1]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = Qwen3MoeCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 7, 7, 7])
+ self.assertAllEqual(x["token_ids"], [5, 17, 27, 26, 19, 1, 1, 1])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": [1, 3, 4, 2, 5, 7, 7, 7],
+ "token_ids": [5, 17, 27, 26, 19, 1, 1, 1],
"padding_mask": [1, 1, 1, 1, 1, 0, 0, 0],
}
preprocessor = Qwen3MoeCausalLMPreprocessor(**self.init_kwargs)
diff --git a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py
index 94f0bb2fd5..698fc08ad6 100644
--- a/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py
+++ b/keras_hub/src/models/qwen3_moe/qwen3_moe_causal_lm_test.py
@@ -17,19 +17,23 @@
class Qwen3MoeCausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab += ["<|im_end|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<|im_end|>", "<|endoftext|>", "!"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = Qwen3MoeCausalLMPreprocessor(
Qwen3MoeTokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=7,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = Qwen3MoeBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_query_heads=4,
num_key_value_heads=2,
@@ -52,7 +56,7 @@ def test_causal_lm_basics(self):
cls=Qwen3MoeCausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 7, 8),
+ expected_output_shape=(2, 7, self.vocabulary_size),
)
def test_generate(self):
diff --git a/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor_test.py b/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor_test.py
index cf52afd244..8890f6db71 100644
--- a/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_preprocessor_test.py
@@ -9,13 +9,16 @@
class QwenMoeCausalLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab += ["<|eot_id|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<|endoftext|>", "<|eot_id|>", "!"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = QwenMoeTokenizer(
vocabulary=self.vocab,
merges=self.merges,
@@ -33,10 +36,10 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[1, 3, 4, 2, 5, 6, 0, 0]],
+ "token_ids": [[5, 17, 27, 26, 19, 1, 0, 0]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0]],
},
- [[3, 4, 2, 5, 6, 0, 0, 0]],
+ [[17, 27, 26, 19, 1, 0, 0, 0]],
[[1, 1, 1, 1, 1, 0, 0, 0]],
),
)
@@ -49,21 +52,21 @@ def test_with_end_token(self):
add_end_token=True,
)
x, y, sw = preprocessor(input_data)
- self.assertAllEqual(x["token_ids"], [[1, 3, 4, 2, 5, 6, 0, 0]] * 4)
+ self.assertAllEqual(x["token_ids"], [[5, 17, 27, 26, 19, 1, 0, 0]] * 4)
self.assertAllEqual(x["padding_mask"], [[1, 1, 1, 1, 1, 1, 0, 0]] * 4)
- self.assertAllEqual(y, [[3, 4, 2, 5, 6, 0, 0, 0]] * 4)
+ self.assertAllEqual(y, [[17, 27, 26, 19, 1, 0, 0, 0]] * 4)
self.assertAllEqual(sw, [[1, 1, 1, 1, 1, 0, 0, 0]] * 4)
def test_generate_preprocess(self):
input_data = "airplane at airport"
preprocessor = QwenMoeCausalLMPreprocessor(**self.init_kwargs)
x = preprocessor.generate_preprocess(input_data)
- self.assertAllEqual(x["token_ids"], [1, 3, 4, 2, 5, 0, 0, 0])
+ self.assertAllEqual(x["token_ids"], [5, 17, 27, 26, 19, 0, 0, 0])
self.assertAllEqual(x["padding_mask"], [1, 1, 1, 1, 1, 0, 0, 0])
def test_generate_postprocess(self):
input_data = {
- "token_ids": [1, 3, 4, 2, 5, 6, 0, 0],
+ "token_ids": [5, 17, 27, 26, 19, 1, 0, 0],
"padding_mask": [1, 1, 1, 1, 1, 1, 0, 0],
}
preprocessor = QwenMoeCausalLMPreprocessor(**self.init_kwargs)
diff --git a/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_test.py b/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_test.py
index 4947ff3781..20142d36ef 100644
--- a/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_test.py
+++ b/keras_hub/src/models/qwen_moe/qwen_moe_causal_lm_test.py
@@ -21,19 +21,23 @@
class QwenMoeCausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab += ["<|eot_id|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["<|endoftext|>", "<|eot_id|>", "!"]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = QwenMoeCausalLMPreprocessor(
QwenMoeTokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=7,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = QwenMoeBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_query_heads=4,
num_key_value_heads=2,
@@ -55,7 +59,7 @@ def test_causal_lm_basics(self):
cls=QwenMoeCausalLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 7, 8),
+ expected_output_shape=(2, 7, self.vocabulary_size),
)
def test_flash_attention_call(self):
diff --git a/keras_hub/src/models/roberta/roberta_masked_lm_preprocessor_test.py b/keras_hub/src/models/roberta/roberta_masked_lm_preprocessor_test.py
index 378825c53f..10be0635c9 100644
--- a/keras_hub/src/models/roberta/roberta_masked_lm_preprocessor_test.py
+++ b/keras_hub/src/models/roberta/roberta_masked_lm_preprocessor_test.py
@@ -9,12 +9,16 @@
class RobertaMaskedLMPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["", "", "", "air", "Ġair", "plane", "Ġat"]
- self.vocab += ["port", ""]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["", "", "", ""]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = RobertaTokenizer(
vocabulary=self.vocab, merges=self.merges
)
@@ -25,7 +29,7 @@ def setUp(self):
"mask_token_rate": 1.0,
"random_token_rate": 0.0,
"mask_selection_length": 4,
- "sequence_length": 12,
+ "sequence_length": 10,
}
self.input_data = [" airplane airport"]
@@ -36,11 +40,11 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[0, 8, 8, 8, 8, 2, 1, 1, 1, 1, 1, 1]],
- "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]],
+ "token_ids": [[3, 1, 1, 1, 1, 0, 2, 2, 2, 2]],
+ "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0]],
"mask_positions": [[1, 2, 3, 4]],
},
- [[4, 5, 4, 7]],
+ [[27, 18, 27, 20]],
[[1.0, 1.0, 1.0, 1.0]],
),
)
@@ -50,15 +54,15 @@ def test_no_masking_zero_rate(self):
self.tokenizer,
mask_selection_rate=0.0,
mask_selection_length=4,
- sequence_length=12,
+ sequence_length=10,
)
input_data = [" airplane airport"]
self.assertAllClose(
no_mask_preprocessor(input_data),
(
{
- "token_ids": [[0, 4, 5, 4, 7, 2, 1, 1, 1, 1, 1, 1]],
- "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0, 0, 0]],
+ "token_ids": [[3, 27, 18, 27, 20, 0, 2, 2, 2, 2]],
+ "padding_mask": [[1, 1, 1, 1, 1, 1, 0, 0, 0, 0]],
"mask_positions": [[0, 0, 0, 0]],
},
[[0, 0, 0, 0]],
diff --git a/keras_hub/src/models/roberta/roberta_masked_lm_test.py b/keras_hub/src/models/roberta/roberta_masked_lm_test.py
index 4a287895fe..34036f7043 100644
--- a/keras_hub/src/models/roberta/roberta_masked_lm_test.py
+++ b/keras_hub/src/models/roberta/roberta_masked_lm_test.py
@@ -12,12 +12,16 @@
class RobertaMaskedLMTest(TestCase):
def setUp(self):
# Setup model.
- self.vocab = ["", "", "", "air", "Ġair", "plane", "Ġat"]
- self.vocab += ["port", ""]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["", "", "", ""]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = RobertaMaskedLMPreprocessor(
RobertaTokenizer(vocabulary=self.vocab, merges=self.merges),
# Simplify our testing by masking every available token.
@@ -25,10 +29,11 @@ def setUp(self):
mask_token_rate=1.0,
random_token_rate=0.0,
mask_selection_length=5,
- sequence_length=5,
+ sequence_length=10,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = RobertaBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_heads=2,
hidden_dim=2,
@@ -49,7 +54,7 @@ def test_masked_lm_basics(self):
cls=RobertaMaskedLM,
init_kwargs=self.init_kwargs,
train_data=self.train_data,
- expected_output_shape=(2, 5, 9),
+ expected_output_shape=(2, 5, self.vocabulary_size),
)
@pytest.mark.large
diff --git a/keras_hub/src/models/roberta/roberta_text_classifier_preprocessor_test.py b/keras_hub/src/models/roberta/roberta_text_classifier_preprocessor_test.py
index 8cbb8c0a7e..4ce7f4f11e 100644
--- a/keras_hub/src/models/roberta/roberta_text_classifier_preprocessor_test.py
+++ b/keras_hub/src/models/roberta/roberta_text_classifier_preprocessor_test.py
@@ -9,12 +9,16 @@
class RobertaTextClassifierPreprocessorTest(TestCase):
def setUp(self):
- self.vocab = ["", "", "", "air", "Ġair", "plane", "Ġat"]
- self.vocab += ["port", ""]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["", "", "", ""]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.tokenizer = RobertaTokenizer(
vocabulary=self.vocab, merges=self.merges
)
@@ -35,7 +39,7 @@ def test_preprocessor_basics(self):
input_data=self.input_data,
expected_output=(
{
- "token_ids": [[0, 4, 5, 6, 4, 7, 2, 1]],
+ "token_ids": [[3, 27, 18, 28, 27, 20, 0, 2]],
"padding_mask": [[1, 1, 1, 1, 1, 1, 1, 0]],
},
[1], # Pass through labels.
diff --git a/keras_hub/src/models/roberta/roberta_text_classifier_test.py b/keras_hub/src/models/roberta/roberta_text_classifier_test.py
index 4c5fdc8bef..80efcca846 100644
--- a/keras_hub/src/models/roberta/roberta_text_classifier_test.py
+++ b/keras_hub/src/models/roberta/roberta_text_classifier_test.py
@@ -14,18 +14,23 @@
class RobertaTextClassifierTest(TestCase):
def setUp(self):
# Setup model.
- self.vocab = ["", "", "", "air", "Ġair", "plane", "Ġat"]
- self.vocab += ["port", ""]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += ["", "", "", ""]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = RobertaTextClassifierPreprocessor(
RobertaTokenizer(vocabulary=self.vocab, merges=self.merges),
- sequence_length=5,
+ sequence_length=10,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = RobertaBackbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
num_layers=2,
num_heads=2,
hidden_dim=2,
diff --git a/keras_hub/src/models/roberta/roberta_tokenizer_test.py b/keras_hub/src/models/roberta/roberta_tokenizer_test.py
index 43938a4e06..dba5066fb2 100644
--- a/keras_hub/src/models/roberta/roberta_tokenizer_test.py
+++ b/keras_hub/src/models/roberta/roberta_tokenizer_test.py
@@ -6,12 +6,16 @@
class RobertaTokenizerTest(TestCase):
def setUp(self):
- self.vocab = ["", "", "", "air", "Ġair", "plane", "Ġat"]
- self.vocab += ["port", ""]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab += ["", "", "", ""]
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.init_kwargs = {"vocabulary": self.vocab, "merges": self.merges}
self.input_data = [
" airplane at airport",
@@ -23,8 +27,10 @@ def test_tokenizer_basics(self):
cls=RobertaTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- # TODO: should not get tokenized as
- expected_output=[[0, 4, 5, 6, 4, 7, 2, 1], [4, 5, 4, 7]],
+ expected_output=[
+ [29, 23, 14, 24, 23, 16, 31, 30],
+ [23, 14, 23, 16],
+ ],
expected_detokenize_output=[
" airplane at airport",
" airplane airport",
diff --git a/keras_hub/src/models/sam3/sam3_pc_image_segmenter_test.py b/keras_hub/src/models/sam3/sam3_pc_image_segmenter_test.py
index 4bfa3041f8..6f69d2a47b 100644
--- a/keras_hub/src/models/sam3/sam3_pc_image_segmenter_test.py
+++ b/keras_hub/src/models/sam3/sam3_pc_image_segmenter_test.py
@@ -88,18 +88,15 @@ def setUp(self):
crop_to_aspect_ratio=False,
antialias=True,
)
- self.tokenizer = SAM3Tokenizer(
- {
- "!": 0,
- '"': 1,
- "#": 2,
- "$": 3,
- "%": 4,
- "<|endoftext|>": 5,
- "<|startoftext|>": 6,
- },
- ["i n", "t h", "a n"],
- )
+ merges = ["i n", "t h", "a n"]
+ vocab = []
+ for merge in merges:
+ a, b = merge.split(" ")
+ vocab.extend([a, b, a + b])
+ vocab += ["!", '"', "#", "$", "%", "<|endoftext|>", "<|startoftext|>"]
+ vocab = sorted(set(vocab)) # Remove duplicates
+ vocab = dict([(token, i) for i, token in enumerate(vocab)])
+ self.tokenizer = SAM3Tokenizer(vocab, merges)
self.preprocessor = SAM3PromptableConceptImageSegmenterPreprocessor(
self.tokenizer, self.image_converter
)
diff --git a/keras_hub/src/models/seq_2_seq_lm_preprocessor.py b/keras_hub/src/models/seq_2_seq_lm_preprocessor.py
index 9398c6e2d7..72b2656b06 100644
--- a/keras_hub/src/models/seq_2_seq_lm_preprocessor.py
+++ b/keras_hub/src/models/seq_2_seq_lm_preprocessor.py
@@ -82,6 +82,10 @@ def __init__(
self.encoder_sequence_length = encoder_sequence_length
self.decoder_sequence_length = decoder_sequence_length
+ # TODO(hongyu): Since `Seq2SeqLMPreprocessor` requires TF workflow, we
+ # currently disable the Python workflow for `Seq2SeqLMPreprocessor`.
+ self.tokenizer._allow_python_workflow = False
+
def build(self, input_shape):
# Defer packer creation to `build()` so that we can be sure tokenizer
# assets have loaded when restoring a saved model.
diff --git a/keras_hub/src/models/smollm3/smollm3_causal_lm_test.py b/keras_hub/src/models/smollm3/smollm3_causal_lm_test.py
index cf26647c0d..62053c5334 100644
--- a/keras_hub/src/models/smollm3/smollm3_causal_lm_test.py
+++ b/keras_hub/src/models/smollm3/smollm3_causal_lm_test.py
@@ -14,21 +14,29 @@
class SmolLM3CausalLMTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|begin_of_text|>"]
- self.vocab += ["<|end_of_text|>"]
- self.vocab += [""]
- self.vocab += [""]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab += [
+ "<|begin_of_text|>",
+ "<|end_of_text|>",
+ "",
+ "",
+ "!",
+ ]
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.preprocessor = SmolLM3CausalLMPreprocessor(
SmolLM3Tokenizer(vocabulary=self.vocab, merges=self.merges),
sequence_length=8,
)
+ self.vocabulary_size = self.preprocessor.tokenizer.vocabulary_size()
self.backbone = SmolLM3Backbone(
- vocabulary_size=self.preprocessor.tokenizer.vocabulary_size(),
+ vocabulary_size=self.vocabulary_size,
hidden_dim=64,
intermediate_dim=128,
num_layers=2,
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py
index 3bdf7b647a..5f087eaaf3 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_image_to_image_test.py
@@ -23,11 +23,15 @@
class StableDiffusion3ImageToImageTest(TestCase):
def setUp(self):
# Instantiate the preprocessor.
- vocab = ["air", "plane", "port"]
- vocab += ["<|endoftext|>", "<|startoftext|>"]
- vocab = dict([(token, i) for i, token in enumerate(vocab)])
merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
merges += ["po rt", "pla ne"]
+ vocab = []
+ for merge in merges:
+ a, b = merge.split(" ")
+ vocab.extend([a, b, a + b])
+ vocab += ["<|endoftext|>", "<|startoftext|>"]
+ vocab = sorted(set(vocab)) # Remove duplicates
+ vocab = dict([(token, i) for i, token in enumerate(vocab)])
clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True)
clip_g_tokenizer = CLIPTokenizer(vocab, merges)
clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer)
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
index 074cc1429c..6dcb038335 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_inpaint_test.py
@@ -23,11 +23,15 @@
class StableDiffusion3InpaintTest(TestCase):
def setUp(self):
# Instantiate the preprocessor.
- vocab = ["air", "plane", "port"]
- vocab += ["<|endoftext|>", "<|startoftext|>"]
- vocab = dict([(token, i) for i, token in enumerate(vocab)])
merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
merges += ["po rt", "pla ne"]
+ vocab = []
+ for merge in merges:
+ a, b = merge.split(" ")
+ vocab.extend([a, b, a + b])
+ vocab += ["<|endoftext|>", "<|startoftext|>"]
+ vocab = sorted(set(vocab)) # Remove duplicates
+ vocab = dict([(token, i) for i, token in enumerate(vocab)])
clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True)
clip_g_tokenizer = CLIPTokenizer(vocab, merges)
clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer)
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py
index 46d69be381..48a839661f 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_preprocessor_test.py
@@ -10,11 +10,15 @@
class StableDiffusion3TextToImagePreprocessorTest(TestCase):
def setUp(self):
- vocab = ["air", "plane", "port"]
- vocab += ["<|endoftext|>", "<|startoftext|>"]
- vocab = dict([(token, i) for i, token in enumerate(vocab)])
merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
merges += ["po rt", "pla ne"]
+ vocab = []
+ for merge in merges:
+ a, b = merge.split(" ")
+ vocab.extend([a, b, a + b])
+ vocab = sorted(set(vocab)) # Remove duplicates
+ vocab += ["<|endoftext|>", "<|startoftext|>"]
+ vocab = dict([(token, i) for i, token in enumerate(vocab)])
clip_l_tokenizer = CLIPTokenizer(
vocabulary=vocab, merges=merges, pad_with_end_token=True
)
@@ -56,5 +60,5 @@ def test_generate_preprocess(self):
x = preprocessor.generate_preprocess(self.input_data)
self.assertIn("clip_l", x)
self.assertIn("clip_g", x)
- self.assertAllEqual(x["clip_l"][0], [4, 0, 1, 3, 3, 3, 3, 3])
- self.assertAllEqual(x["clip_g"][0], [4, 0, 1, 3, 0, 0, 0, 0])
+ self.assertAllEqual(x["clip_l"][0], [19, 2, 12, 18, 18, 18, 18, 18])
+ self.assertAllEqual(x["clip_g"][0], [19, 2, 12, 18, 0, 0, 0, 0])
diff --git a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py
index 609146af51..f6c1cd5314 100644
--- a/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py
+++ b/keras_hub/src/models/stable_diffusion_3/stable_diffusion_3_text_to_image_test.py
@@ -23,11 +23,15 @@
class StableDiffusion3TextToImageTest(TestCase):
def setUp(self):
# Instantiate the preprocessor.
- vocab = ["air", "plane", "port"]
- vocab += ["<|endoftext|>", "<|startoftext|>"]
- vocab = dict([(token, i) for i, token in enumerate(vocab)])
merges = ["a i", "p l", "n e", "p o", "r t", "ai r", "pl a"]
merges += ["po rt", "pla ne"]
+ vocab = []
+ for merge in merges:
+ a, b = merge.split(" ")
+ vocab.extend([a, b, a + b])
+ vocab += ["<|endoftext|>", "<|startoftext|>"]
+ vocab = sorted(set(vocab)) # Remove duplicates
+ vocab = dict([(token, i) for i, token in enumerate(vocab)])
clip_l_tokenizer = CLIPTokenizer(vocab, merges, pad_with_end_token=True)
clip_g_tokenizer = CLIPTokenizer(vocab, merges)
clip_l_preprocessor = CLIPPreprocessor(clip_l_tokenizer)
diff --git a/keras_hub/src/models/text_classifier_preprocessor.py b/keras_hub/src/models/text_classifier_preprocessor.py
index 4061d5e940..10151d5220 100644
--- a/keras_hub/src/models/text_classifier_preprocessor.py
+++ b/keras_hub/src/models/text_classifier_preprocessor.py
@@ -79,6 +79,11 @@ def __init__(
self.sequence_length = sequence_length
self.truncate = truncate
+ # TODO(hongyu): Since `MultiSegmentPacker` requires TF workflow, we
+ # currently disable the Python workflow for
+ # `TextClassifierPreprocessor`.
+ self.tokenizer._allow_python_workflow = False
+
def build(self, input_shape):
super().build(input_shape)
# Defer masker creation to `build()` so that we can be sure tokenizer
diff --git a/keras_hub/src/models/v2/causal_lm_preprocessor.py b/keras_hub/src/models/v2/causal_lm_preprocessor.py
index 02877b10e3..e80f34e267 100644
--- a/keras_hub/src/models/v2/causal_lm_preprocessor.py
+++ b/keras_hub/src/models/v2/causal_lm_preprocessor.py
@@ -1,9 +1,7 @@
import keras
from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.layers.preprocessing.v2.start_end_packer import (
- StartEndPacker,
-)
+from keras_hub.src.layers.preprocessing.start_end_packer import StartEndPacker
from keras_hub.src.models.preprocessor import Preprocessor
diff --git a/keras_hub/src/models/whisper/whisper_tokenizer_test.py b/keras_hub/src/models/whisper/whisper_tokenizer_test.py
index fdeec80124..a133d09538 100644
--- a/keras_hub/src/models/whisper/whisper_tokenizer_test.py
+++ b/keras_hub/src/models/whisper/whisper_tokenizer_test.py
@@ -6,22 +6,26 @@
class WhisperTokenizerTest(TestCase):
def setUp(self):
- self.vocab = ["!", "air", "Ġair", "plane", "Ġat", "port"]
- self.vocab += ["<|endoftext|>"]
- self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.merges = ["Ġ a", "Ġ t", "Ġ i", "Ġ b", "a i", "p l", "n e"]
self.merges += ["Ġa t", "p o", "r t", "Ġt h", "ai r", "pl a", "po rt"]
self.merges += ["Ġai r", "Ġa i", "pla ne"]
+ self.vocab = []
+ for merge in self.merges:
+ a, b = merge.split(" ")
+ self.vocab.extend([a, b, a + b])
+ self.vocab = sorted(set(self.vocab)) # Remove duplicates
+ self.vocab += ["!", "<|endoftext|>"]
+ self.vocab = dict([(token, i) for i, token in enumerate(self.vocab)])
self.special_tokens = {
- "<|startoftranscript|>": 9,
- "<|endoftext|>": 10,
- "<|notimestamps|>": 11,
- "<|transcribe|>": 12,
- "<|translate|>": 13,
+ "<|startoftranscript|>": 31, # len(self.vocab) == 31 at this point
+ "<|endoftext|>": 32,
+ "<|notimestamps|>": 33,
+ "<|transcribe|>": 34,
+ "<|translate|>": 35,
}
self.language_tokens = {
- "<|en|>": 14,
- "<|fr|>": 15,
+ "<|en|>": 36,
+ "<|fr|>": 37,
}
self.init_kwargs = {
"vocabulary": self.vocab,
@@ -39,17 +43,24 @@ def test_tokenizer_basics(self):
cls=WhisperTokenizer,
init_kwargs=self.init_kwargs,
input_data=self.input_data,
- expected_output=[[2, 3, 4, 2, 5, 10], [2, 3, 2, 5]],
+ expected_output=[
+ [23, 14, 24, 23, 16, 32],
+ [23, 14, 23, 16],
+ ],
+ expected_detokenize_output=[
+ " airplane at airport<|endoftext|>",
+ " airplane airport",
+ ],
)
def test_special_tokens(self):
tokenizer = WhisperTokenizer(**self.init_kwargs)
- self.assertEqual(tokenizer.bos_token_id, 9)
- self.assertEqual(tokenizer.eos_token_id, 10)
- self.assertEqual(tokenizer.pad_token_id, 10)
- self.assertEqual(tokenizer.no_timestamps_token_id, 11)
- self.assertEqual(tokenizer.translate_token_id, 13)
- self.assertEqual(tokenizer.transcribe_token_id, 12)
+ self.assertEqual(tokenizer.bos_token_id, 31)
+ self.assertEqual(tokenizer.eos_token_id, 32)
+ self.assertEqual(tokenizer.pad_token_id, 32)
+ self.assertEqual(tokenizer.no_timestamps_token_id, 33)
+ self.assertEqual(tokenizer.transcribe_token_id, 34)
+ self.assertEqual(tokenizer.translate_token_id, 35)
def test_errors_missing_special_tokens(self):
with self.assertRaises(ValueError):
diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer.py b/keras_hub/src/tokenizers/byte_pair_tokenizer.py
index bc9fc19f25..f63e86bc30 100644
--- a/keras_hub/src/tokenizers/byte_pair_tokenizer.py
+++ b/keras_hub/src/tokenizers/byte_pair_tokenizer.py
@@ -10,12 +10,19 @@
from typing import Iterable
import keras
+import numpy as np
import regex as re
+import tokenizers
from keras.src.saving import serialization_lib
+from tokenizers import decoders
+from tokenizers import models
+from tokenizers import pre_tokenizers
from keras_hub.src.api_export import keras_hub_export
from keras_hub.src.tokenizers import tokenizer
+from keras_hub.src.utils.tensor_utils import assert_tf_libs_installed
from keras_hub.src.utils.tensor_utils import convert_to_ragged_batch
+from keras_hub.src.utils.tensor_utils import in_tf_function
from keras_hub.src.utils.tensor_utils import is_int_dtype
from keras_hub.src.utils.tensor_utils import is_string_dtype
from keras_hub.src.utils.tensor_utils import preprocessing_function
@@ -52,6 +59,12 @@
# SPLIT_PATTERN_2 = rf"""[\s६{SPECIAL_WHITESPACES}]$"""
SPLIT_PATTERN_2 = rf"""[ \t\r\f\v६{SPECIAL_WHITESPACES}]$"""
+# From Llama3's tokenizer implementation.
+SPLIT_PATTERN_TOKENIZERS = (
+ "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| "
+ "?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
+)
+
def create_alts_for_unsplittable_tokens(unsplittable_tokens):
# Create alternates for all special tokens that will be not split during
@@ -249,30 +262,31 @@ class BytePairTokenizer(tokenizer.Tokenizer):
Examples:
Tokenize
- >>> vocab = {"butter": 1, "fly": 2}
>>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
+ >>> vocab = []
+ >>> [vocab.extend([a, b, a + b]) for a, b in [m.split(" ") for m in merge]]
+ >>> vocab = sorted(set(vocab)) # Remove duplicates
+ >>> vocab = dict([(token, i) for i, token in enumerate(vocab)])
>>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(vocab, merge)
>>> outputs = tokenizer("butterfly")
>>> np.array(outputs)
- array([1, 2], dtype=int32)
+ array([3, 8])
>>> seq1, seq2 = tokenizer(["butterfly", "butter"])
>>> np.array(seq1)
- array([1, 2])
+ array([3, 8])
>>> np.array(seq2)
- array([1])
+ array([3])
>>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(
... vocab, merge, sequence_length=2)
>>> seq1, seq2 = tokenizer(["butterfly", "butter"])
>>> np.array(seq1)
- array([1, 2], dtype=int32)
+ array([3, 8])
>>> np.array(seq2)
- array([1, 0], dtype=int32)
+ array([3, 0])
Detokenize
- >>> vocab = {"butter": 1, "fly": 2}
- >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
>>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(vocab, merge)
- >>> tokenizer.detokenize([[1, 2]])
+ >>> tokenizer.detokenize([[3, 8]])
['butterfly']
"""
@@ -292,7 +306,10 @@ def __init__(
f"Received: dtype={dtype}"
)
- super().__init__(dtype=dtype, **kwargs)
+ _allow_python_workflow = kwargs.pop("_allow_python_workflow", True)
+ super().__init__(
+ dtype=dtype, _allow_python_workflow=_allow_python_workflow, **kwargs
+ )
self.sequence_length = sequence_length
self.add_prefix_space = add_prefix_space
if unsplittable_tokens is None:
@@ -300,16 +317,6 @@ def __init__(
self.unsplittable_tokens = unsplittable_tokens
self.file_assets = [VOCAB_FILENAME, MERGES_FILENAME]
- # Create byte <=> unicode mapping. This is useful for handling
- # whitespace tokens.
- byte_list, unicode_list = bytes_to_unicode()
- self.byte2unicode = create_static_hashtable(
- byte_list, unicode_list, default=""
- )
- self.unicode2byte = create_static_hashtable(
- unicode_list, byte_list, default=""
- )
-
self.set_vocabulary_and_merges(vocabulary, merges)
def save_assets(self, dir_path):
@@ -326,17 +333,124 @@ def load_assets(self, dir_path):
merges_path = os.path.join(dir_path, MERGES_FILENAME)
self.set_vocabulary_and_merges(vocab_path, merges_path)
+ def _set_vocabulary_and_merges_tf(self, vocabulary, merges):
+ assert_tf_libs_installed(self.__class__.__name__)
+ self.vocabulary = vocabulary.copy()
+ self.merges = merges
+ for merge in merges:
+ if "#version:" in merge.lstrip():
+ continue
+ a, b = str(merge).split(" ")
+ if a not in vocabulary or b not in vocabulary:
+ raise ValueError(
+ f"Merge rule '{merge}' contains token '{a}' or '{b}' that "
+ "is not in the vocabulary."
+ )
+
+ # Create byte <=> unicode mapping. This is useful for handling
+ # whitespace tokens.
+ byte_list, unicode_list = bytes_to_unicode()
+ self.byte2unicode = create_static_hashtable(
+ byte_list, unicode_list, default=""
+ )
+ self.unicode2byte = create_static_hashtable(
+ unicode_list, byte_list, default=""
+ )
+
+ self.cache = BytePairTokenizerCache()
+ if self.unsplittable_tokens:
+ # Put special tokens into cache, so it won't be further split and
+ # merged.
+ self.cache.insert(
+ self.unsplittable_tokens, self.unsplittable_tokens
+ )
+
+ # Create mapping between string tokens to int ids, and vice versa.
+ byte_pairs = [x[0] for x in self.vocabulary.items()]
+ byte_pair_encoding_indices = [x[1] for x in self.vocabulary.items()]
+ self.token_to_id_map = create_static_hashtable(
+ byte_pairs,
+ byte_pair_encoding_indices,
+ default=-1,
+ )
+ self.id_to_token_map = create_static_hashtable(
+ byte_pair_encoding_indices,
+ byte_pairs,
+ default="",
+ )
+
+ # Create ranking of merge rules, this is the same as order of merge
+ # pairs in `self.merges`.
+ self.merge_ranks_lookup_default = len(self.merges) + 1
+ self.merge_ranks = create_static_hashtable(
+ self.merges,
+ list(range(len(self.merges))),
+ default=self.merge_ranks_lookup_default,
+ )
+
+ # Dummy attrs for serialization compatibility.
+ if not hasattr(self, "_tokenizer"):
+ self._tokenizer = None
+
+ def _set_vocabulary_and_merges_tokenizers(self, vocabulary, merges):
+ self.vocabulary = vocabulary.copy()
+ self.merges = merges
+ _merges = []
+ for merge in merges:
+ if "#version:" in merge.lstrip():
+ continue
+ a, b = str(merge).split(" ")
+ if a not in vocabulary or b not in vocabulary:
+ raise ValueError(
+ f"Merge rule '{merge}' contains token '{a}' or '{b}' that "
+ "is not in the vocabulary."
+ )
+ _merges.append((a, b))
+
+ self._tokenizer = tokenizers.Tokenizer(
+ models.BPE(vocab=vocabulary, merges=_merges)
+ )
+ if self.unsplittable_tokens:
+ self._tokenizer.add_special_tokens(self.unsplittable_tokens)
+ # Ensure the implementation matches Llama3's tokenizer behavior.
+ self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
+ [
+ pre_tokenizers.Split(
+ pattern=SPLIT_PATTERN_TOKENIZERS, behavior="isolated"
+ ),
+ pre_tokenizers.ByteLevel(
+ add_prefix_space=self.add_prefix_space, use_regex=False
+ ),
+ ]
+ )
+ self._tokenizer.decoder = decoders.ByteLevel()
+
+ # Dummy attrs for serialization compatibility.
+ if not hasattr(self, "cache"):
+ self.byte2unicode = None
+ self.unicode2byte = None
+ self.cache = None
+ self.id_to_token_map = None
+ self.token_to_id_map = None
+ self.merge_ranks_lookup_default = None
+ self.merge_ranks = None
+
def set_vocabulary_and_merges(self, vocabulary, merges):
"""Set the vocabulary and merge rules from data or files."""
if vocabulary is None or merges is None:
# Clear vocab related state.
self.vocabulary = None
self.merges = None
+            # State populated by `_set_vocabulary_and_merges_tf`.
+ self.byte2unicode = None
+ self.unicode2byte = None
self.cache = None
self.id_to_token_map = None
self.token_to_id_map = None
self.merge_ranks_lookup_default = None
self.merge_ranks = None
+            # State populated by `_set_vocabulary_and_merges_tokenizers`.
+ self._tokenizer = None
return
if isinstance(vocabulary, str):
@@ -352,9 +466,9 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
f"Vocabulary file: '{vocabulary}'"
)
with open(vocabulary, "r", encoding="utf-8") as f:
- self.vocabulary = json.load(f)
+ vocabulary = json.load(f)
elif isinstance(vocabulary, dict):
- self.vocabulary = vocabulary.copy()
+ vocabulary = vocabulary.copy()
else:
raise ValueError(
"Vocabulary must be an file path or dictionary mapping string "
@@ -374,46 +488,44 @@ def set_vocabulary_and_merges(self, vocabulary, merges):
f"Merges file: '{merges}'"
)
with open(merges, encoding="utf-8") as f:
- self.merges = [bp.rstrip() for bp in f]
+ merges = [bp.rstrip() for bp in f]
elif isinstance(merges, Iterable):
- self.merges = list(merges)
+ merges = list(merges)
else:
raise ValueError(
"Merges must be a file path or a list of merge rules. "
f"Received: `type(merges)={type(merges)}`"
)
- self.cache = BytePairTokenizerCache()
- if self.unsplittable_tokens:
- # Put special tokens into cache, so it won't be further split and
- # merged.
- self.cache.insert(
- self.unsplittable_tokens, self.unsplittable_tokens
+        # When using `BytePairTokenizer` with `tf.data`, it must be built
+        # outside the `tf.data` pipeline. So we always attempt the TF setup,
+        # silently skipping it if the TF libraries are not installed.
+ try:
+ self._set_vocabulary_and_merges_tf(vocabulary, merges)
+ except ImportError:
+ pass
+ if self._allow_python_workflow:
+ self._set_vocabulary_and_merges_tokenizers(vocabulary, merges)
+
+ self._update_special_token_ids()
+
+ def _check_vocabulary(self):
+ if self.vocabulary is None:
+ raise ValueError(
+ "No vocabulary has been set for BytePairTokenizer. Make sure "
+ "to pass `vocabulary` and `merges` arguments when creating the "
+ "layer."
)
- # Create mapping between string tokens to int ids, and vice versa.
- byte_pairs = [x[0] for x in self.vocabulary.items()]
- byte_pair_encoding_indices = [x[1] for x in self.vocabulary.items()]
- self.token_to_id_map = create_static_hashtable(
- byte_pairs,
- byte_pair_encoding_indices,
- default=-1,
- )
- self.id_to_token_map = create_static_hashtable(
- byte_pair_encoding_indices,
- byte_pairs,
- default="",
- )
+ def _maybe_initialized_tf(self):
+ if getattr(self, "cache", None) is None:
+ self._set_vocabulary_and_merges_tf(self.vocabulary, self.merges)
- # Create ranking of merge rules, this is the same as order of merge
- # pairs in `self.merges`.
- self.merge_ranks_lookup_default = len(self.merges) + 1
- self.merge_ranks = create_static_hashtable(
- self.merges,
- list(range(len(self.merges))),
- default=self.merge_ranks_lookup_default,
- )
- self._update_special_token_ids()
+ def _maybe_initialized_tokenizers(self):
+ if getattr(self, "_tokenizer", None) is None:
+ self._set_vocabulary_and_merges_tokenizers(
+ self.vocabulary, self.merges
+ )
def get_vocabulary(self):
"""Get the tokenizer vocabulary as a list of strings tokens."""
@@ -425,25 +537,55 @@ def vocabulary_size(self):
self._check_vocabulary()
return len(self.vocabulary)
- def id_to_token(self, id):
- """Convert an integer id to a string token."""
+ def _id_to_token_tf(self, id):
+ self._maybe_initialized_tf()
# This will be slow, but keep memory usage down compared to building a
# dict. Assuming the main use case is looking up a few special tokens
# early in the vocab, this should be fine.
- self._check_vocabulary()
-
keys = self.get_vocabulary()
for token in keys:
if self.vocabulary[token] == id:
return token
raise ValueError(f"`id` is out of the vocabulary. Received: {id}")
+ def _id_to_token_tokenizers(self, id):
+ self._maybe_initialized_tokenizers()
+ try:
+ token = self._tokenizer.id_to_token(id)
+ except OverflowError:
+ token = None
+ if token is None:
+ raise ValueError(f"Id {id} is out of vocabulary range.")
+ return token
+
+ def id_to_token(self, id):
+ """Convert an integer id to a string token."""
+ self._check_vocabulary()
+ if not self._allow_python_workflow or in_tf_function():
+ return self._id_to_token_tf(id)
+ else:
+ return self._id_to_token_tokenizers(id)
+
+ def _token_to_id_tf(self, token):
+ self._maybe_initialized_tf()
+ return self.vocabulary[token]
+
+ def _token_to_id_tokenizers(self, token):
+ self._maybe_initialized_tokenizers()
+ token_id = self._tokenizer.token_to_id(token)
+ if token_id is None:
+ raise ValueError(f"Token '{token}' is not in the vocabulary.")
+ return token_id
+
def token_to_id(self, token):
"""Convert a string token to an integer id."""
self._check_vocabulary()
- return self.vocabulary[token]
+ if not self._allow_python_workflow or in_tf_function():
+ return self._token_to_id_tf(token)
+ else:
+ return self._token_to_id_tokenizers(token)
- def _bpe_merge_one_step(self, words, mask):
+ def _bpe_merge_one_step_tf(self, words, mask):
"""Perform one step of byte-pair merge."""
# Get all word pairs.
first, second = words[:, :-1], words[:, 1:]
@@ -524,7 +666,7 @@ def _bpe_merge_one_step(self, words, mask):
words = remove_strings_from_inputs(words, "")
return [words, mask]
- def _bpe_merge(self, inputs):
+ def _bpe_merge_tf(self, inputs):
"""Perform byte-pair merge for each word in the inputs."""
num_words = tf.shape(inputs)[0]
@@ -535,7 +677,7 @@ def loop_condition(_, mask):
initial_mask = tf.fill((num_words,), True)
merged_words, _ = tf.while_loop(
loop_condition,
- tf.function(self._bpe_merge_one_step),
+ tf.function(self._bpe_merge_one_step_tf),
loop_vars=[
inputs,
initial_mask,
@@ -547,17 +689,28 @@ def loop_condition(_, mask):
)
return merged_words
- def _check_vocabulary(self):
- if self.vocabulary is None:
- raise ValueError(
- "No vocabulary has been set for BytePairTokenizer. Make sure "
- "to pass `vocabulary` and `merges` arguments when creating the "
- "layer."
- )
+ def _bpe_merge_and_update_cache_tf(self, tokens):
+ """Process unseen tokens and add to cache."""
+
+ def _transform_bytes(tokens):
+ """Map token bytes to unicode using `byte2unicode`."""
+ split_bytes = tf.strings.bytes_split(tokens)
+ split_unicode = self.byte2unicode.lookup(split_bytes)
+ return split_unicode
+
+ words = _transform_bytes(tokens)
+ tokenized_words = self._bpe_merge_tf(words)
+
+ # For each word, join all its token by a whitespace,
+ # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
+ tokenized_words = tf.strings.reduce_join(
+ tokenized_words, axis=1, separator=" "
+ )
+ self.cache.insert(tokens, tokenized_words)
@preprocessing_function
- def tokenize(self, inputs):
- self._check_vocabulary()
+ def _tokenize_tf(self, inputs):
+ self._maybe_initialized_tf()
if self.add_prefix_space:
inputs = tf.strings.join([" ", inputs])
@@ -570,7 +723,6 @@ def tokenize(self, inputs):
"`tokenize()` inputs should be a string, list of strings, or "
f"string tensor with rank < 2. Received: {inputs}"
)
-
raw_tokens = split_strings_for_bpe(inputs, self.unsplittable_tokens)
token_row_splits = raw_tokens.row_splits
flat_tokens = raw_tokens.flat_values
@@ -578,14 +730,13 @@ def tokenize(self, inputs):
# Check cache.
cache_lookup = self.cache.lookup(flat_tokens)
cache_mask = cache_lookup == ""
-
has_unseen_words = tf.math.reduce_any(
(cache_lookup == "") & (flat_tokens != "")
)
def process_unseen_tokens():
unseen_tokens = tf.boolean_mask(flat_tokens, cache_mask)
- self._bpe_merge_and_update_cache(unseen_tokens)
+ self._bpe_merge_and_update_cache_tf(unseen_tokens)
return self.cache.lookup(flat_tokens)
# If `has_unseen_words == True`, it means not all tokens are in cache,
@@ -595,7 +746,6 @@ def process_unseen_tokens():
process_unseen_tokens,
lambda: cache_lookup,
)
-
tokens = tf.strings.split(tokenized_words, sep=" ")
if self.compute_dtype != tf.string:
# Encode merged tokens.
@@ -617,12 +767,71 @@ def process_unseen_tokens():
if unbatched:
tokens = tf.squeeze(tokens, 0)
tf.ensure_shape(tokens, shape=[self.sequence_length])
-
return tokens
- @preprocessing_function
- def detokenize(self, inputs):
+ def _tokenize_tokenizers(self, inputs):
+ self._maybe_initialized_tokenizers()
+
+ def _canonicalize_tokenize_inputs(inputs):
+ if isinstance(inputs, str):
+ return [inputs], False
+ elif isinstance(inputs, (tuple, list)):
+ if not all(isinstance(i, str) for i in inputs):
+ raise ValueError(
+ "If a list or tuple is provided as input, all elements "
+ "must be strings. "
+ f"Received: {inputs}"
+ )
+ return list(inputs), True
+ elif tf is not None and isinstance(inputs, tf.Tensor):
+ unbatched = inputs.shape.rank == 0
+ if unbatched:
+ inputs = tf.expand_dims(inputs, 0)
+ inputs = inputs.numpy().tolist()
+ inputs = keras.tree.map_structure(
+ lambda x: x.decode("utf-8"), inputs
+ )
+ return inputs, not unbatched
+ else:
+ raise ValueError(
+ "Input should be a string or a list of strings. "
+ f"Received: {inputs}"
+ )
+
+ inputs, batched = _canonicalize_tokenize_inputs(inputs)
+ outputs = self._tokenizer.encode_batch(inputs)
+ if is_int_dtype(self.compute_dtype):
+ batched_tokens = [o.ids for o in outputs]
+ else:
+ batched_tokens = [o.tokens for o in outputs]
+
+ # Convert to a dense output if `sequence_length` is set.
+ if self.sequence_length:
+ # Truncate sequences to `sequence_length`.
+ batched_tokens = [
+ tokens[: self.sequence_length] for tokens in batched_tokens
+ ]
+ # Pad sequences to `sequence_length`.
+ pad_token_id = getattr(self, "pad_token_id", 0)
+ batched_tokens = [
+ tokens + [pad_token_id] * (self.sequence_length - len(tokens))
+ for tokens in batched_tokens
+ ]
+
+ if not batched:
+ batched_tokens = batched_tokens[0]
+ return batched_tokens
+
+ def tokenize(self, inputs):
self._check_vocabulary()
+ if not self._allow_python_workflow or in_tf_function():
+ return self._tokenize_tf(inputs)
+ else:
+ return self._tokenize_tokenizers(inputs)
+
+ @preprocessing_function
+ def _detokenize_tf(self, inputs):
+ self._maybe_initialized_tf()
inputs, unbatched, rectangular = convert_to_ragged_batch(inputs)
inputs = tf.cast(inputs, self.dtype)
unicode_text = tf.strings.reduce_join(
@@ -637,28 +846,66 @@ def detokenize(self, inputs):
outputs = tf.squeeze(outputs, 0)
return outputs
- def compute_output_spec(self, input_spec):
- return keras.KerasTensor(
- input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
+ def _detokenize_tokenizers(self, inputs):
+ self._maybe_initialized_tokenizers()
+
+ def _canonicalize_detokenize_inputs(inputs):
+ is_batched = True
+ if isinstance(inputs, int):
+ inputs = [[inputs]]
+ is_batched = False
+ elif isinstance(inputs, (tuple, list)):
+ if not inputs or isinstance(inputs[0], int):
+ # Unbatched list of ints.
+ inputs = [list(inputs)]
+ is_batched = False
+ else:
+ # Batched list of lists of ints.
+ inputs = [list(seq) for seq in inputs]
+ elif isinstance(inputs, np.ndarray) or keras.ops.is_tensor(inputs):
+ inputs = keras.ops.convert_to_numpy(inputs)
+ if inputs.ndim == 0:
+ inputs = [[inputs.item()]]
+ is_batched = False
+ elif inputs.ndim == 1:
+ inputs = [inputs.tolist()]
+ is_batched = False
+ elif inputs.ndim == 2:
+ inputs = inputs.tolist()
+ else:
+ raise ValueError(
+ f"Array must be 0, 1 or 2 dimensional, "
+ f"got {inputs.shape}."
+ )
+ else:
+ raise ValueError(
+ "Input should be an integer, a list of integers, backend "
+ f"tensor or numpy array. Received: {inputs}"
+ )
+ return inputs, is_batched
+
+ inputs, batched = _canonicalize_detokenize_inputs(inputs)
+ outputs = self._tokenizer.decode_batch(
+ inputs, skip_special_tokens=False
)
+ if not batched:
+ outputs = outputs[0]
+ return outputs
- def _transform_bytes(self, tokens):
- """Map token bytes to unicode using `byte2unicode`."""
- split_bytes = tf.strings.bytes_split(tokens)
- split_unicode = self.byte2unicode.lookup(split_bytes)
- return split_unicode
+ def detokenize(self, inputs):
+ self._check_vocabulary()
+ if not self._allow_python_workflow or in_tf_function():
+ return self._detokenize_tf(inputs)
+ else:
+ return self._detokenize_tokenizers(inputs)
- def _bpe_merge_and_update_cache(self, tokens):
- """Process unseen tokens and add to cache."""
- words = self._transform_bytes(tokens)
- tokenized_words = self._bpe_merge(words)
+ def call(self, inputs, *args, training=None, **kwargs):
+ return self.tokenize(inputs, *args, **kwargs)
- # For each word, join all its token by a whitespace,
- # e.g., ["dragon", "fly"] => "dragon fly" for hash purpose.
- tokenized_words = tf.strings.reduce_join(
- tokenized_words, axis=1, separator=" "
+ def compute_output_spec(self, input_spec):
+ return keras.KerasTensor(
+ input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
)
- self.cache.insert(tokens, tokenized_words)
def get_config(self):
config = super().get_config()
diff --git a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py
index 985ce4b891..8c1a9d55d1 100644
--- a/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py
+++ b/keras_hub/src/tokenizers/byte_pair_tokenizer_test.py
@@ -195,3 +195,13 @@ def test_safe_mode_vocabulary_file_disallowed(self):
r"model archive.*Vocabulary file: .*vocab\.json",
):
tokenizer.set_vocabulary_and_merges(vocab_path, merges_path)
+
+
+class BytePairTokenizerDisallowPythonWorkflowTest(BytePairTokenizerTest):
+ def setUp(self):
+ super().setUp()
+ self.tokenizer = BytePairTokenizer(
+ vocabulary=VOCAB_PATH,
+ merges=MERGE_PATH,
+ _allow_python_workflow=False,
+ )
diff --git a/keras_hub/src/tokenizers/v2/byte_pair_tokenizer.py b/keras_hub/src/tokenizers/v2/byte_pair_tokenizer.py
deleted file mode 100644
index 5f35865db2..0000000000
--- a/keras_hub/src/tokenizers/v2/byte_pair_tokenizer.py
+++ /dev/null
@@ -1,361 +0,0 @@
-import json
-import os
-import warnings
-from typing import Iterable
-
-import keras
-import numpy as np
-import tokenizers
-from keras.src.saving import serialization_lib
-from tokenizers import decoders
-from tokenizers import models
-from tokenizers import pre_tokenizers
-
-from keras_hub.src.api_export import keras_hub_export
-from keras_hub.src.tokenizers import tokenizer
-from keras_hub.src.utils.tensor_utils import is_int_dtype
-from keras_hub.src.utils.tensor_utils import is_string_dtype
-
-VOCAB_FILENAME = "vocabulary.json"
-MERGES_FILENAME = "merges.txt"
-
-# From Llama3's tokenizer implementation.
-SPLIT_PATTERN = (
- "(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\\r\\n\\p{L}\\p{N}]?\\p{L}+|\\p{N}{1,3}| "
- "?[^\\s\\p{L}\\p{N}]+[\\r\\n]*|\\s*[\\r\\n]+|\\s+(?!\\S)|\\s+"
-)
-
-
-@keras_hub_export("keras_hub.tokenizers.v2.BytePairTokenizer")
-class BytePairTokenizer(tokenizer.Tokenizer):
- """Bype-pair encoding tokenizer layer.
-
- This BPE tokenizer provides the same functionality as the official GPT-2
- tokenizer. Given the same `vocabulary` which maps tokens to ids, and
- `merges` which describes BPE merge rules, it should provide the same output
- as OpenAI implementation (https://github.com/openai/gpt-2/blob/master/src/encoder.py).
- Different from OpenAI, this implementation is graph-compatible, so you can
- use it within a `tf.data` pipeline.
-
- If input is a batch of strings (rank > 0):
- By default, the layer will output a list of lists. If `sequence_length` is
- set, the layer will output a list of lists where all inputs have been padded
- or truncated to `sequence_length`.
- If input is a scalar string (rank == 0):
- By default, the layer will output a list with static shape. If
- `sequence_length` is set, the output will be a list of shape
- `[sequence_length]`.
-
- Args:
- vocabulary: string or dict, maps token to integer ids. If it is a
- string, it should be the file path to a json file.
- merges: string or list, contains the merge rule. If it is a string,
- it should be the file path to merge rules. The merge rule file
- should have one merge rule per line.
- sequence_length: int. If set, the output will be
- padded or truncated to the `sequence_length`. Defaults to `None`.
- add_prefix_space: bool. Whether to add an
- initial space to the input. This tokenizer is whitespace aware,
- and will tokenize a word with a leading space differently. Adding
- a prefix space to the first word will cause it to be tokenized
- equivalently to all subsequent words in the sequence.
- Defaults to `False`.
- unsplittable_tokens: list. A list of strings that will
- never be split during the word-level splitting applied before the
- byte-pair encoding. This can be used to ensure special tokens map to
- unique indices in the vocabulary, even if these special tokens
- contain splittable characters such as punctuation. Special tokens
- must still be included in `vocabulary`. Defaults to `None`.
-
- Examples:
-
- Tokenize
- >>> vocab = {"butter": 1, "fly": 2}
- >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
- >>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(vocab, merge)
- >>> outputs = tokenizer("butterfly")
- >>> np.array(outputs)
- array([1, 2], dtype=int32)
- >>> seq1, seq2 = tokenizer(["butterfly", "butter"])
- >>> np.array(seq1)
- array([1, 2])
- >>> np.array(seq2)
- array([1])
- >>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(
- ... vocab, merge, sequence_length=2)
- >>> seq1, seq2 = tokenizer(["butterfly", "butter"])
- >>> np.array(seq1)
- array([1, 2], dtype=int32)
- >>> np.array(seq2)
- array([1, 0], dtype=int32)
-
- Detokenize
- >>> vocab = {"butter": 1, "fly": 2}
- >>> merge = ["b u", "t t", "e r", "bu tt", "butt er", "f l", "fl y"]
- >>> tokenizer = keras_hub.tokenizers.BytePairTokenizer(vocab, merge)
- >>> tokenizer.detokenize([[1, 2]])
- ['butterfly']
- """
-
- def __init__(
- self,
- vocabulary=None,
- merges=None,
- sequence_length=None,
- add_prefix_space=False,
- unsplittable_tokens=None,
- dtype="int32",
- **kwargs,
- ):
- if not is_int_dtype(dtype) and not is_string_dtype(dtype):
- raise ValueError(
- "Output dtype must be an integer type or a string. "
- f"Received: dtype={dtype}"
- )
-
- super().__init__(dtype=dtype, **kwargs)
- self.sequence_length = sequence_length
- self.add_prefix_space = add_prefix_space
- if unsplittable_tokens is None:
- unsplittable_tokens = self.special_tokens
- self.unsplittable_tokens = unsplittable_tokens
- self.file_assets = [VOCAB_FILENAME, MERGES_FILENAME]
-
- self.set_vocabulary_and_merges(vocabulary, merges)
-
- def save_assets(self, dir_path):
- vocab_path = os.path.join(dir_path, VOCAB_FILENAME)
- merges_path = os.path.join(dir_path, MERGES_FILENAME)
- with open(vocab_path, "w", encoding="utf-8") as file:
- file.write(json.dumps(dict(self.vocabulary)))
- with open(merges_path, "w", encoding="utf-8") as file:
- for merge in self.merges:
- file.write(f"{merge}\n")
-
- def load_assets(self, dir_path):
- vocab_path = os.path.join(dir_path, VOCAB_FILENAME)
- merges_path = os.path.join(dir_path, MERGES_FILENAME)
- self.set_vocabulary_and_merges(vocab_path, merges_path)
-
- def set_vocabulary_and_merges(self, vocabulary, merges):
- """Set the vocabulary and merge rules from data or files."""
- if vocabulary is None or merges is None:
- # Clear vocab related state.
- self.vocabulary = None
- self.merges = None
- return
-
- if isinstance(vocabulary, str):
- if serialization_lib.in_safe_mode():
- raise ValueError(
- "Requested the loading of a vocabulary file outside of the "
- "model archive. This carries a potential risk of loading "
- "arbitrary and sensitive files and thus it is disallowed "
- "by default. If you trust the source of the artifact, you "
- "can override this error by passing `safe_mode=False` to "
- "the loading function, or calling "
- "`keras.config.enable_unsafe_deserialization()`. "
- f"Vocabulary file: '{vocabulary}'"
- )
- with open(vocabulary, "r", encoding="utf-8") as f:
- self.vocabulary = json.load(f)
- elif isinstance(vocabulary, dict):
- self.vocabulary = vocabulary.copy()
- else:
- raise ValueError(
- "Vocabulary must be an file path or dictionary mapping string "
- "token to int ids. Received: "
- f"`type(vocabulary)={type(vocabulary)}`."
- )
- if isinstance(merges, str):
- if serialization_lib.in_safe_mode():
- raise ValueError(
- "Requested the loading of a merges file outside of the "
- "model archive. This carries a potential risk of loading "
- "arbitrary and sensitive files and thus it is disallowed "
- "by default. If you trust the source of the artifact, you "
- "can override this error by passing `safe_mode=False` to "
- "the loading function, or calling "
- "`keras.config.enable_unsafe_deserialization()`. "
- f"Merges file: '{merges}'"
- )
- with open(merges, encoding="utf-8") as f:
- merges = [bp.rstrip() for bp in f]
- elif isinstance(merges, Iterable):
- merges = list(merges)
- else:
- raise ValueError(
- "Merges must be a file path or a list of merge rules. "
- f"Received: `type(merges)={type(merges)}`"
- )
- self.merges = merges
- _merges = []
- for merge in merges:
- a, b = merge.split(" ")
- if a not in self.vocabulary or b not in self.vocabulary:
- warnings.warn(
- f"Merge pair ({a}, {b}) contains a token not in the "
- "vocabulary. Skipping."
- )
- continue
- _merges.append((a, b))
-
- self._tokenizer = tokenizers.Tokenizer(
- models.BPE(vocab=self.vocabulary, merges=_merges)
- )
- if self.unsplittable_tokens:
- self._tokenizer.add_special_tokens(self.unsplittable_tokens)
- # Ensure the implementation matches Llama3's tokenizer behavior.
- self._tokenizer.pre_tokenizer = pre_tokenizers.Sequence(
- [
- pre_tokenizers.Split(
- pattern=SPLIT_PATTERN, behavior="isolated"
- ),
- pre_tokenizers.ByteLevel(
- add_prefix_space=self.add_prefix_space, use_regex=False
- ),
- ]
- )
- self._tokenizer.decoder = decoders.ByteLevel()
- self._update_special_token_ids()
-
- def get_vocabulary(self):
- """Get the tokenizer vocabulary as a list of strings tokens."""
- self._check_vocabulary()
- return self._tokenizer.get_vocab().keys()
-
- def vocabulary_size(self):
- """Get the integer size of the tokenizer vocabulary."""
- self._check_vocabulary()
- return self._tokenizer.get_vocab_size()
-
- def id_to_token(self, id):
- """Convert an integer id to a string token."""
- self._check_vocabulary()
- try:
- token = self._tokenizer.id_to_token(id)
- except OverflowError:
- token = None
- if token is None:
- raise ValueError(f"Id {id} is out of vocabulary range.")
- return token
-
- def token_to_id(self, token):
- """Convert a string token to an integer id."""
- self._check_vocabulary()
- token_id = self._tokenizer.token_to_id(token)
- if token_id is None:
- raise ValueError(f"Token '{token}' is not in the vocabulary.")
- return token_id
-
- def _check_vocabulary(self):
- if self.vocabulary is None:
- raise ValueError(
- "No vocabulary has been set for BytePairTokenizer. Make sure "
- "to pass `vocabulary` and `merges` arguments when creating the "
- "layer."
- )
-
- def _canonicalize_tokenize_inputs(self, inputs):
- if isinstance(inputs, str):
- return [inputs], False
- elif isinstance(inputs, (tuple, list)):
- if not all(isinstance(i, str) for i in inputs):
- raise ValueError(
- "If a list or tuple is provided as input, all elements "
- "must be strings. "
- f"Received: {inputs}"
- )
- return list(inputs), True
- else:
- raise ValueError(
- "Input should be a string or a list of strings. "
- f"Received: {inputs}"
- )
-
- def _canonicalize_detokenize_inputs(self, inputs):
- is_batched = True
- if isinstance(inputs, int):
- inputs = [[inputs]]
- is_batched = False
- elif isinstance(inputs, (tuple, list)):
- if not inputs or isinstance(inputs[0], int):
- # Unbatched list of ints.
- inputs = [list(inputs)]
- is_batched = False
- else:
- # Batched list of lists of ints.
- inputs = [list(seq) for seq in inputs]
- elif isinstance(inputs, np.ndarray) or keras.ops.is_tensor(inputs):
- inputs = keras.ops.convert_to_numpy(inputs)
- if inputs.ndim == 0:
- inputs = [[inputs.item()]]
- is_batched = False
- elif inputs.ndim == 1:
- inputs = [inputs.tolist()]
- is_batched = False
- elif inputs.ndim == 2:
- inputs = inputs.tolist()
- else:
- raise ValueError(
- f"Array must be 0, 1 or 2 dimensional, got {inputs.shape}."
- )
- else:
- raise ValueError(
- "Input should be an integer, a list of integers, backend "
- f"tensor or numpy array. Received: {inputs}"
- )
- return inputs, is_batched
-
- def tokenize(self, inputs):
- self._check_vocabulary()
- inputs, batched = self._canonicalize_tokenize_inputs(inputs)
- outputs = self._tokenizer.encode_batch(inputs)
- if is_int_dtype(self.compute_dtype):
- batched_tokens = [o.ids for o in outputs]
- else:
- batched_tokens = [o.tokens for o in outputs]
-
- # Convert to a dense output if `sequence_length` is set.
- if self.sequence_length:
- # Truncate sequences to `sequence_length`.
- batched_tokens = [
- tokens[: self.sequence_length] for tokens in batched_tokens
- ]
- # Pad sequences to `sequence_length`.
- pad_token_id = getattr(self, "pad_token_id", 0)
- batched_tokens = [
- tokens + [pad_token_id] * (self.sequence_length - len(tokens))
- for tokens in batched_tokens
- ]
-
- if not batched:
- batched_tokens = batched_tokens[0]
- return batched_tokens
-
- def detokenize(self, inputs):
- self._check_vocabulary()
- inputs, batched = self._canonicalize_detokenize_inputs(inputs)
- outputs = self._tokenizer.decode_batch(inputs)
- if not batched:
- outputs = outputs[0]
- return outputs
-
- def call(self, inputs, *args, training=None, **kwargs):
- return self.tokenize(inputs, *args, **kwargs)
-
- def compute_output_spec(self, input_spec):
- return keras.KerasTensor(
- input_spec.shape + (self.sequence_length,), dtype=self.compute_dtype
- )
-
- def get_config(self):
- config = super().get_config()
- config.update(
- {
- "sequence_length": self.sequence_length,
- "add_prefix_space": self.add_prefix_space,
- "unsplittable_tokens": self.unsplittable_tokens,
- }
- )
- return config
diff --git a/keras_hub/src/tokenizers/v2/byte_pair_tokenizer_test.py b/keras_hub/src/tokenizers/v2/byte_pair_tokenizer_test.py
deleted file mode 100644
index e8cfd5c6f8..0000000000
--- a/keras_hub/src/tokenizers/v2/byte_pair_tokenizer_test.py
+++ /dev/null
@@ -1,173 +0,0 @@
-import keras
-from keras.src.saving import serialization_lib
-
-from keras_hub.src.tests.test_case import TestCase
-from keras_hub.src.tokenizers.v2.byte_pair_tokenizer import BytePairTokenizer
-
-VOCAB_PATH = keras.utils.get_file(
- None,
- "https://storage.googleapis.com/keras-nlp/models/roberta_base/vocab.json",
-)
-MERGE_PATH = keras.utils.get_file(
- None,
- "https://storage.googleapis.com/keras-nlp/models/roberta_base/merges.txt",
-)
-
-
-class BytePairTokenizerTest(TestCase):
- def setUp(self):
- super().setUp()
- self.tokenizer = BytePairTokenizer(
- vocabulary=VOCAB_PATH, merges=MERGE_PATH
- )
-
- def test_tokenize_list_input(self):
- input_data = ["brown.", "black."]
- call_output = self.tokenizer(input_data)
- tokenize_output = self.tokenizer.tokenize(input_data)
- expected = [[31876, 4], [14178, 4]]
- self.assertAllEqual(call_output, expected)
- self.assertAllEqual(tokenize_output, expected)
-
- def test_tokenize_string_output(self):
- input_data = ["quick brown fox.", "slow black bear."]
- tokenizer = BytePairTokenizer(
- vocabulary=VOCAB_PATH, merges=MERGE_PATH, dtype="string"
- )
- call_output = tokenizer(input_data)
- expected = [
- ["quick", "Ġbrown", "Ġfox", "."],
- ["slow", "Ġblack", "Ġbear", "."],
- ]
- self.assertAllEqual(call_output, expected)
-
- def test_tokenize_with_special_tokens(self):
- vocab = {"sp": 0, "s": 1, "p": 2}
- merges = ["s p"]
- tokenizer = BytePairTokenizer(
- vocabulary=vocab,
- merges=merges,
- unsplittable_tokens=["s", "p"],
- )
- output = tokenizer("sp")
- self.assertAllEqual(output, [1, 2])
-
- # If not setting special tokens, "sp" is one token.
- tokenizer = BytePairTokenizer(
- vocabulary=vocab,
- merges=merges,
- )
- output = tokenizer("sp")
- self.assertAllEqual(output, [0])
-
- def test_tokenize_prefix_space(self):
- input_data = ["brown.", "black."]
- tokenizer = BytePairTokenizer(
- vocabulary=VOCAB_PATH,
- merges=MERGE_PATH,
- dtype="string",
- add_prefix_space=True,
- )
- call_output = tokenizer(input_data)
-
- expected = [["Ġbrown", "."], ["Ġblack", "."]]
- self.assertAllEqual(call_output, expected)
-
- def test_tokenize_scalar_input(self):
- input_data = "brown."
- encoded = self.tokenizer.tokenize(input_data)
- self.assertAllEqual(encoded, [31876, 4])
-
- def test_detokenize_scalar_input(self):
- input_data = ["quick brown fox."]
- encoded = self.tokenizer.tokenize(input_data)
- decoded = self.tokenizer.detokenize(encoded)
- self.assertAllEqual(input_data, decoded)
-
- def test_detokenize_list_input(self):
- input_data = ["quick brown fox.", "slow bear"]
- encoded = self.tokenizer.tokenize(input_data)
- decoded = self.tokenizer.detokenize(encoded)
- self.assertAllEqual(input_data, decoded)
-
- def test_error_id_out_of_vocabulary(self):
- with self.assertRaises(ValueError):
- self.tokenizer.id_to_token(self.tokenizer.vocabulary_size())
- with self.assertRaises(ValueError):
- self.tokenizer.id_to_token(-1)
-
- def test_whitespace_split(self):
- input_data = "\n\n\n s"
- encoded = self.tokenizer(input_data)
- self.assertAllEqual(encoded, [50140, 50118, 1437, 579])
-
- input_data = " \n\n\ns"
- encoded = self.tokenizer(input_data)
- self.assertAllEqual(encoded, [1437, 1437, 50140, 50118, 29])
-
- # This is important for Llama3 which uses the \n\n sequence in chat
- # templates: \n\n must be tokenized as a single token
- input_data = "Hello\n\nHello"
- encoded = self.tokenizer(input_data)
- self.assertAllEqual(encoded, [31414, 50140, 31414])
-
- input_data = "Hello\n\n\n\nHello"
- encoded = self.tokenizer(input_data)
- self.assertAllEqual(encoded, [31414, 50140, 50140, 31414])
-
- input_data = "Hello\n\n"
- encoded = self.tokenizer(input_data)
- self.assertAllEqual(encoded, [31414, 50140])
-
- input_data = "Hello\n\n\n\n"
- encoded = self.tokenizer(input_data)
- self.assertAllEqual(encoded, [31414, 50140, 50140])
-
- def test_special_whitespace(self):
- input_data = "\xa0 \xa0 \x3000 s"
- encoded = self.tokenizer(input_data)
- self.assertAllEqual(encoded, [50141, 50143, 12096, 579])
-
- def test_cjk_input(self):
- input_data = "素晴らしい!芭比Q啦~"
- # Black formats long list by one element per line, which is bad to read.
- expected = [36714, 20024, 21402, 37127, 27, 20024, 48945, 47918]
- expected += [47780, 43251, 4394, 10172, 36484, 27969, 12410, 37127]
- expected += [10965, 10674, 1864, 42393, 15722, 18164, 43251, 10809]
- expected += [17772]
- encoded = self.tokenizer(input_data)
- self.assertAllEqual(encoded, expected)
-
- def test_config(self):
- input_data = ["the quick brown whale."]
- cloned_tokenizer = BytePairTokenizer.from_config(
- self.tokenizer.get_config()
- )
- cloned_tokenizer.set_vocabulary_and_merges(
- self.tokenizer.vocabulary, self.tokenizer.merges
- )
- self.assertAllEqual(
- self.tokenizer(input_data),
- cloned_tokenizer(input_data),
- )
-
- def test_safe_mode_vocabulary_file_disallowed(self):
- import os
-
- temp_dir = self.get_temp_dir()
- vocab_path = os.path.join(temp_dir, "vocab.json")
- merges_path = os.path.join(temp_dir, "merges.txt")
-
- with open(vocab_path, "w") as file:
- file.write('{"<|endoftext|>": 0, "the": 1, "quick": 2}')
- with open(merges_path, "w") as file:
- file.write("t h\nthe quick")
-
- tokenizer = BytePairTokenizer()
- with serialization_lib.SafeModeScope(True):
- with self.assertRaisesRegex(
- ValueError,
- r"Requested the loading of a vocabulary file outside of the "
- r"model archive.*Vocabulary file: .*vocab\.json",
- ):
- tokenizer.set_vocabulary_and_merges(vocab_path, merges_path)
diff --git a/keras_hub/src/utils/tensor_utils.py b/keras_hub/src/utils/tensor_utils.py
index 26c25202b6..cacf0696c8 100644
--- a/keras_hub/src/utils/tensor_utils.py
+++ b/keras_hub/src/utils/tensor_utils.py
@@ -215,6 +215,46 @@ def convert(x):
return keras.tree.map_structure(convert, x)
+def convert_preprocessing_outputs_python(x):
+    """Convert outputs after preprocessing to a backend-agnostic format.
+
+    This function is used to convert `tf.Tensor` and `tf.RaggedTensor` outputs
+ from preprocessing layers to either:
+
+ - The correct tensor type for the Keras backend framework.
+ - Python lists, in the case of string data.
+
+ Examples:
+ ```python
+ # A batch of three samples each with two string segments.
+ x = (["hi", "yo", "hey"], ["bye", "ciao", ""])
+ keras_hub.utils.convert_preprocessing_outputs_python(x)
+
+ # A batch of features in a dictionary.
+ x = {
+ "text": ["hi", "hello", "hey"],
+ "images": np.ones((3, 64, 64, 3)),
+ "labels": [1, 0, 1],
+ }
+ keras_hub.utils.convert_preprocessing_outputs_python(x)
+ ```
+ """
+ if in_no_convert_scope():
+ return x
+
+ def convert(x):
+ if x is None:
+ return x
+ if isinstance(x, str):
+ return tensor_to_list(x)
+ dtype = None
+ if hasattr(x, "dtype"):
+ dtype = keras.backend.standardize_dtype(x.dtype)
+ return ops.convert_to_tensor(x, dtype=dtype)
+
+ return keras.tree.map_structure(convert, x)
+
+
def _decode_strings_to_utf8(inputs):
"""Recursively decodes to list of strings with 'utf-8' encoding."""
if isinstance(inputs, bytes):
diff --git a/keras_hub/src/utils/transformers/export/gpt2_test.py b/keras_hub/src/utils/transformers/export/gpt2_test.py
index 563e39c38d..b557189080 100644
--- a/keras_hub/src/utils/transformers/export/gpt2_test.py
+++ b/keras_hub/src/utils/transformers/export/gpt2_test.py
@@ -34,6 +34,9 @@ def test_export_to_hf(self):
"i": 8,
"c": 9,
"k": 10,
+ "Ġq": 11,
+ "ui": 12,
+ "ck": 13,
}
merges = ["Ġ q", "u i", "c k"]
diff --git a/keras_hub/src/utils/transformers/export/qwen_test.py b/keras_hub/src/utils/transformers/export/qwen_test.py
index 60acc54123..78feb2271c 100644
--- a/keras_hub/src/utils/transformers/export/qwen_test.py
+++ b/keras_hub/src/utils/transformers/export/qwen_test.py
@@ -38,9 +38,12 @@ def test_export_to_hf(self):
"c": 11,
"k": 12,
" ": 13, # Space
+ "qu": 14,
+ "ic": 15,
+ "ck": 16,
}
# Add a dummy merge to satisfy initialization
- merges = ["q u", "i c", "k"]
+ merges = ["q u", "i c", "c k"]
temp_dir = self.get_temp_dir()
vocab_path = os.path.join(temp_dir, "vocab.json")