From af584b475079cbd674ee818158fe8fa8096e954e Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 17 Apr 2025 10:26:40 +0530 Subject: [PATCH 1/3] Update gemma3_causal_lm_preprocessor.py Added checks for invalid inputs --- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py index f7372a9cbd..4815efe1fd 100644 --- a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +++ b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py @@ -512,6 +512,10 @@ def call( # Extract text part of the input. prompts, responses = x["prompts"], x["responses"] + tf.debugging.assert_shapes([ + (prompts,('N',)), + (responses,('N',)) + ]) # Find out if the input is batched/not batched. Uprank if not batched. # In other preprocessors, we don't have to do this, but here, all From dc4ae8c7fb068baf6b397c2765aa55586a4b1fb1 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 17 Apr 2025 10:39:10 +0530 Subject: [PATCH 2/3] Update gemma3_causal_lm_preprocessor.py --- keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py index 4815efe1fd..a60d095a2d 100644 --- a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py +++ b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor.py @@ -512,10 +512,7 @@ def call( # Extract text part of the input. prompts, responses = x["prompts"], x["responses"] - tf.debugging.assert_shapes([ - (prompts,('N',)), - (responses,('N',)) - ]) + tf.debugging.assert_shapes([(prompts, ("N",)), (responses, ("N",))]) # Find out if the input is batched/not batched. Uprank if not batched. # In other preprocessors, we don't have to do this, but here, all From 07c5c7792b18b9793d971d915ad558d971daf5f7 Mon Sep 17 00:00:00 2001 From: RAHUL KUMAR Date: Thu, 17 Apr 2025 11:28:22 +0530 Subject: [PATCH 3/3] Update gemma3_causal_lm_preprocessor_test.py Added tests to check invalid inputs --- .../gemma3/gemma3_causal_lm_preprocessor_test.py | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py index 210da7d24f..17ee4dab0f 100644 --- a/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py +++ b/keras_hub/src/models/gemma3/gemma3_causal_lm_preprocessor_test.py @@ -167,6 +167,20 @@ def test_generate_postprocess(self): x = preprocessor.generate_postprocess(input_data) self.assertAllEqual(x, "the quick brown fox \n\n ") + def test_invalid_shape(self): + with self.assertRaises(ValueError): + input_data = { + "prompts": ["hello world", "this is testing"], + "responses": [""], + } + self.text_preprocessor(input_data) + with self.assertRaises(ValueError): + input_data = { + "prompts": ["hello world", "this is testing"], + "responses": ["hello", "", ""], + } + self.text_preprocessor(input_data) + @pytest.mark.kaggle_key_required @pytest.mark.extra_large def test_all_presets(self):