|
6 | 6 | from keras.src.layers.preprocessing.index_lookup import listify_tensors |
7 | 7 | from keras.src.layers.preprocessing.string_lookup import StringLookup |
8 | 8 | from keras.src.saving import serialization_lib |
| 9 | +from keras.src.trainers.data_adapters.grain_dataset_adapter import ( |
| 10 | + GrainDatasetAdapter, |
| 11 | +) |
9 | 12 | from keras.src.utils import argument_validation |
10 | 13 | from keras.src.utils import backend_utils |
11 | 14 | from keras.src.utils import tf_utils |
| 15 | +from keras.src.utils.module_utils import grain |
12 | 16 | from keras.src.utils.module_utils import tensorflow as tf |
13 | 17 |
|
14 | 18 |
|
| 19 | +def _extract_adapt_batch(batch): |
| 20 | + """Extract text input from a batch; handle (x,) or (x, y) or (x, y, w).""" |
| 21 | + if isinstance(batch, (tuple, list)) and len(batch) > 0: |
| 22 | + return batch[0] |
| 23 | + return batch |
| 24 | + |
| 25 | + |
15 | 26 | @keras_export("keras.layers.TextVectorization") |
16 | 27 | class TextVectorization(Layer): |
17 | 28 | """A preprocessing layer which maps text features to integer sequences. |
@@ -403,22 +414,34 @@ def adapt(self, data, batch_size=None, steps=None): |
403 | 414 |
|
404 | 415 | Arguments: |
405 | 416 | data: The data to train on. It can be passed either as a |
406 | | - batched `tf.data.Dataset`, as a list of strings, |
407 | | - or as a NumPy array. |
| 417 | + batched `tf.data.Dataset`, a Grain dataset |
| 418 | + (`grain.MapDataset`, `grain.IterDataset`, or |
| 419 | + `grain.DataLoader`), a list of strings, or a NumPy array. |
| 420 | + For dataset inputs, each batch may be just the text tensor |
| 421 | + or a tuple `(text, labels)` (only the text is used). |
408 | 422 | steps: Integer or `None`. |
409 | 423 | Total number of steps (batches of samples) to process. |
410 | | - If `data` is a `tf.data.Dataset`, and `steps` is `None`, |
411 | | - `adapt()` will run until the input dataset is exhausted. |
412 | | - When passing an infinitely |
413 | | - repeating dataset, you must specify the `steps` argument. This |
| 424 | + If `data` is a `tf.data.Dataset` or a Grain dataset, and |
| 425 | + `steps` is `None`, `adapt()` will run until the input |
| 426 | + dataset is exhausted. When passing an infinitely repeating |
| 427 | + dataset, you must specify the `steps` argument. This |
414 | 428 | argument is not supported with array inputs or list inputs. |
415 | 429 | """ |
416 | 430 | self.reset_state() |
417 | 431 | if isinstance(data, tf.data.Dataset): |
418 | 432 | if steps is not None: |
419 | 433 | data = data.take(steps) |
420 | 434 | for batch in data: |
421 | | - self.update_state(batch) |
| 435 | + self.update_state(_extract_adapt_batch(batch)) |
| 436 | + elif grain.available and isinstance( |
| 437 | + data, (grain.MapDataset, grain.IterDataset, grain.DataLoader) |
| 438 | + ): |
| 439 | + dataset_adapter = GrainDatasetAdapter(data) |
| 440 | + tf_dataset = dataset_adapter.get_tf_dataset() |
| 441 | + if steps is not None: |
| 442 | + tf_dataset = tf_dataset.take(steps) |
| 443 | + for batch in tf_dataset: |
| 444 | + self.update_state(_extract_adapt_batch(batch)) |
422 | 445 | else: |
423 | 446 | data = tf_utils.ensure_tensor(data, dtype="string") |
424 | 447 | if data.shape.rank == 1: |
|
0 commit comments