Skip to content

Commit 569159d

Browse files
committed
feat: Add Grain dataset support for TextVectorization.adapt()
1 parent d9966a5 commit 569159d

File tree

2 files changed

+85
-7
lines changed

2 files changed

+85
-7
lines changed

keras/src/layers/preprocessing/text_vectorization.py

Lines changed: 30 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,23 @@
66
from keras.src.layers.preprocessing.index_lookup import listify_tensors
77
from keras.src.layers.preprocessing.string_lookup import StringLookup
88
from keras.src.saving import serialization_lib
9+
from keras.src.trainers.data_adapters.grain_dataset_adapter import (
10+
GrainDatasetAdapter,
11+
)
912
from keras.src.utils import argument_validation
1013
from keras.src.utils import backend_utils
1114
from keras.src.utils import tf_utils
15+
from keras.src.utils.module_utils import grain
1216
from keras.src.utils.module_utils import tensorflow as tf
1317

1418

19+
def _extract_adapt_batch(batch):
20+
"""Extract text input from a batch; handle (x,) or (x, y) or (x, y, w)."""
21+
if isinstance(batch, (tuple, list)) and len(batch) > 0:
22+
return batch[0]
23+
return batch
24+
25+
1526
@keras_export("keras.layers.TextVectorization")
1627
class TextVectorization(Layer):
1728
"""A preprocessing layer which maps text features to integer sequences.
@@ -403,22 +414,34 @@ def adapt(self, data, batch_size=None, steps=None):
403414
404415
Arguments:
405416
data: The data to train on. It can be passed either as a
406-
batched `tf.data.Dataset`, as a list of strings,
407-
or as a NumPy array.
417+
batched `tf.data.Dataset`, a Grain dataset
418+
(`grain.MapDataset`, `grain.IterDataset`, or
419+
`grain.DataLoader`), a list of strings, or a NumPy array.
420+
For dataset inputs, each batch may be just the text tensor
421+
or a tuple `(text, labels)` (only the text is used).
408422
steps: Integer or `None`.
409423
Total number of steps (batches of samples) to process.
410-
If `data` is a `tf.data.Dataset`, and `steps` is `None`,
411-
`adapt()` will run until the input dataset is exhausted.
412-
When passing an infinitely
413-
repeating dataset, you must specify the `steps` argument. This
424+
If `data` is a `tf.data.Dataset` or a Grain dataset, and
425+
`steps` is `None`, `adapt()` will run until the input
426+
dataset is exhausted. When passing an infinitely repeating
427+
dataset, you must specify the `steps` argument. This
414428
argument is not supported with array inputs or list inputs.
415429
"""
416430
self.reset_state()
417431
if isinstance(data, tf.data.Dataset):
418432
if steps is not None:
419433
data = data.take(steps)
420434
for batch in data:
421-
self.update_state(batch)
435+
self.update_state(_extract_adapt_batch(batch))
436+
elif grain.available and isinstance(
437+
data, (grain.MapDataset, grain.IterDataset, grain.DataLoader)
438+
):
439+
dataset_adapter = GrainDatasetAdapter(data)
440+
tf_dataset = dataset_adapter.get_tf_dataset()
441+
if steps is not None:
442+
tf_dataset = tf_dataset.take(steps)
443+
for batch in tf_dataset:
444+
self.update_state(_extract_adapt_batch(batch))
422445
else:
423446
data = tf_utils.ensure_tensor(data, dtype="string")
424447
if data.shape.rank == 1:

keras/src/layers/preprocessing/text_vectorization_test.py

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,61 @@ def test_adapt_with_steps(self):
493493
self.assertIn("bar", vocab)
494494
self.assertNotIn("unique_word", vocab)
495495

496+
def test_adapt_with_grain_dataset(self):
    """adapt() should build a full vocabulary from a batched Grain dataset."""
    pytest.importorskip("grain")
    import grain as grain_module

    class _StringSource(grain_module.sources.RandomAccessDataSource):
        # Minimal random-access source serving a fixed list of strings.
        def __init__(self, samples):
            self._samples = np.asarray(samples, dtype=object)

        def __len__(self):
            return len(self._samples)

        def __getitem__(self, index):
            return self._samples[index]

    corpus = ["foo bar", "bar baz", "baz foo"]
    grain_ds = (
        grain_module.MapDataset.source(_StringSource(corpus))
        .to_iter_dataset()
        .batch(batch_size=2)
    )
    layer = layers.TextVectorization(output_mode="int")
    layer.adapt(grain_ds)
    vocab = layer.get_vocabulary()
    # All three tokens from the corpus must be present.
    for token in ("foo", "bar", "baz"):
        self.assertIn(token, vocab)
523+
524+
def test_adapt_with_grain_dataset_and_steps(self):
    """adapt(steps=N) should stop after N batches of a Grain dataset."""
    pytest.importorskip("grain")
    import grain as grain_module

    class _StringSource(grain_module.sources.RandomAccessDataSource):
        # Minimal random-access source serving a fixed list of strings.
        def __init__(self, samples):
            self._samples = np.asarray(samples, dtype=object)

        def __len__(self):
            return len(self._samples)

        def __getitem__(self, index):
            return self._samples[index]

    corpus = ["foo bar", "bar baz", "unique_word"]
    grain_ds = (
        grain_module.MapDataset.source(_StringSource(corpus))
        .to_iter_dataset()
        .batch(batch_size=1)
    )
    layer = layers.TextVectorization(output_mode="int")
    # Only the first two single-item batches are consumed, so
    # "unique_word" (third batch) must not enter the vocabulary.
    layer.adapt(grain_ds, steps=2)
    vocab = layer.get_vocabulary()
    self.assertIn("bar", vocab)
    self.assertNotIn("unique_word", vocab)
550+
496551
def test_invalid_ngrams(self):
497552
with self.assertRaises(ValueError):
498553
layers.TextVectorization(ngrams="invalid")

0 commit comments

Comments
 (0)