diff --git a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py
index f7372a9cbd..a60d095a2d 100644
--- a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py
+++ b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py
@@ -512,6 +512,19 @@ def call(
 
         # Extract text part of the input.
         prompts, responses = x["prompts"], x["responses"]
+        # Validate that `prompts` and `responses` agree in shape *before*
+        # any upranking below. Both may legitimately be scalar (unbatched)
+        # or rank-1 (batched), so a fixed rank-1 `assert_shapes` spec would
+        # wrongly reject valid unbatched string inputs. Comparing the
+        # statically-known converted shapes handles both cases and raises a
+        # `ValueError` deterministically in eager and graph tracing alike.
+        prompts_shape = tf.convert_to_tensor(prompts).shape
+        responses_shape = tf.convert_to_tensor(responses).shape
+        if prompts_shape != responses_shape:
+            raise ValueError(
+                "`prompts` and `responses` should have the same shape. "
+                f"Received: `prompts.shape` = {prompts_shape}, "
+                f"`responses.shape` = {responses_shape}."
+            )
 
         # Find out if the input is batched/not batched. Uprank if not batched.
         # In other preprocessors, we don't have to do this, but here, all
diff --git a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py
index 210da7d24f..17ee4dab0f 100644
--- a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py
+++ b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py
@@ -167,6 +167,24 @@ def test_generate_postprocess(self):
         x = preprocessor.generate_postprocess(input_data)
         self.assertAllEqual(x, "the quick brown fox \n\n ")
 
+    def test_invalid_shape(self):
+        # Mismatched batch sizes between `prompts` and `responses` must be
+        # rejected before tokenization. Build the inputs outside the
+        # `assertRaises` block so only the preprocessor call can raise.
+        input_data = {
+            "prompts": ["hello world", "this is testing"],
+            "responses": [""],
+        }
+        with self.assertRaises(ValueError):
+            self.text_preprocessor(input_data)
+        input_data = {
+            "prompts": ["hello world", "this is testing"],
+            "responses": ["hello", "", ""],
+        }
+        with self.assertRaises(ValueError):
+            self.text_preprocessor(input_data)
+
     @pytest.mark.kaggle_key_required
     @pytest.mark.extra_large
     def test_all_presets(self):