Skip to content

Commit 9492201

Browse files
committed
fix: CI testing failures
1 parent 569159d commit 9492201

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

keras/src/layers/preprocessing/text_vectorization.py

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,9 +6,6 @@
66
from keras.src.layers.preprocessing.index_lookup import listify_tensors
77
from keras.src.layers.preprocessing.string_lookup import StringLookup
88
from keras.src.saving import serialization_lib
9-
from keras.src.trainers.data_adapters.grain_dataset_adapter import (
10-
GrainDatasetAdapter,
11-
)
129
from keras.src.utils import argument_validation
1310
from keras.src.utils import backend_utils
1411
from keras.src.utils import tf_utils
@@ -436,12 +433,12 @@ def adapt(self, data, batch_size=None, steps=None):
436433
elif grain.available and isinstance(
437434
data, (grain.MapDataset, grain.IterDataset, grain.DataLoader)
438435
):
439-
dataset_adapter = GrainDatasetAdapter(data)
440-
tf_dataset = dataset_adapter.get_tf_dataset()
441-
if steps is not None:
442-
tf_dataset = tf_dataset.take(steps)
443-
for batch in tf_dataset:
436+
step = 0
437+
for batch in data:
438+
if steps is not None and step >= steps:
439+
break
444440
self.update_state(_extract_adapt_batch(batch))
441+
step += 1
445442
else:
446443
data = tf_utils.ensure_tensor(data, dtype="string")
447444
if data.shape.rank == 1:

keras/src/layers/preprocessing/text_vectorization_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,10 @@ def test_adapt_with_steps(self):
493493
self.assertIn("bar", vocab)
494494
self.assertNotIn("unique_word", vocab)
495495

496+
@pytest.mark.skipif(
497+
backend.backend() != "tensorflow",
498+
reason="TextVectorization and Grain adapt path require TensorFlow",
499+
)
496500
def test_adapt_with_grain_dataset(self):
497501
pytest.importorskip("grain")
498502
import grain as grain_module
@@ -521,6 +525,10 @@ def __getitem__(self, index):
521525
self.assertIn("bar", vocab)
522526
self.assertIn("baz", vocab)
523527

528+
@pytest.mark.skipif(
529+
backend.backend() != "tensorflow",
530+
reason="TextVectorization and Grain adapt path require TensorFlow",
531+
)
524532
def test_adapt_with_grain_dataset_and_steps(self):
525533
pytest.importorskip("grain")
526534
import grain as grain_module

0 commit comments

Comments
 (0)