Skip to content

Commit 24e4308

Browse files
Fixes compute_output_shape for PaliGemmaVitEncoder and Gemma3VisionEncoderBlock (#2210)
1 parent 13788a9 commit 24e4308

File tree

2 files changed

+2
-2
lines changed

2 files changed

+2
-2
lines changed

keras_hub/src/models/gemma3/gemma3_vision_encoder.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -488,7 +488,7 @@ def compute_output_shape(self, inputs_shape):
488488
# Fix the compatibility issue with Keras 3.1 where
489489
# `compute_output_spec` fails to propagate `inputs_shape`
490490
# correctly, causing it to be `None`.
491-
inputs_shape = [None, None, None]
491+
return [None, None, self.hidden_dim]
492492
return [
493493
None,
494494
(inputs_shape[2] // self.patch_size) ** 2,

keras_hub/src/models/pali_gemma/pali_gemma_vit.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -329,7 +329,7 @@ def compute_output_shape(self, inputs_shape):
329329
# Fix the compatibility issue with Keras 3.1 where
330330
# `compute_output_spec` fails to propagate `inputs_shape`
331331
# correctly, causing it to be `None`.
332-
inputs_shape = [None, None, None]
332+
return [None, None, self.hidden_dim]
333333
return [
334334
inputs_shape[0],
335335
(inputs_shape[1] // self.patch_size) ** 2,

0 commit comments

Comments
 (0)