|
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizerFast
+from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
+from transformers.models.siglip.configuration_siglip import SiglipVisionConfig

+from liger_kernel.transformers import apply_liger_kernel_to_gemma3
 from liger_kernel.transformers import apply_liger_kernel_to_mllama
 from liger_kernel.transformers import apply_liger_kernel_to_paligemma
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
|
 from test.utils import assert_verbose_allclose
 from test.utils import load_tokenizer_config
 from test.utils import multimodal_collate_fn
+from test.utils import revert_liger_kernel_to_gemma3
 from test.utils import revert_liger_kernel_to_mllama
 from test.utils import revert_liger_kernel_to_Paligemma
 from test.utils import revert_liger_kernel_to_qwen2_5_vl
|
 except ImportError:
     PALIGEMMA_AVAILABLE = False

+try:
+    # Gemma3 is only available in transformers>=4.50.0
+    from transformers.models.gemma3.configuration_gemma3 import Gemma3Config
+    from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
+    from transformers.models.gemma3.image_processing_gemma3 import Gemma3ImageProcessor
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
+    from transformers.models.gemma3.processing_gemma3 import Gemma3Processor
+
+    GEMMA3_AVAILABLE = True
+except ImportError:
+    GEMMA3_AVAILABLE = False
+
 from liger_kernel.utils import infer_device

 device = infer_device()
|
         ),
     )

+if GEMMA3_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_gemma3"] = MiniModelConfig(
+        liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_gemma3, fused_linear_cross_entropy=False),
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3,
+        model_class=Gemma3ForConditionalGeneration,
+        mini_model_config=Gemma3Config(
+            vision_config=SiglipVisionConfig(
+                attention_dropout=0.0,
+                hidden_act="gelu_pytorch_tanh",
+                hidden_size=1152,
+                image_size=224,
+                intermediate_size=2048,  # 4304
+                layer_norm_eps=1e-06,
+                num_attention_heads=4,  # 16
+                num_channels=3,
+                num_hidden_layers=4,  # 27
+                num_image_tokens=256,
+                num_positions=256,
+                patch_size=14,
+                projection_dim=1024,  # 2304
+            ).to_dict(),
+            text_config=Gemma3TextConfig(
+                vocab_size=32000,  # 256000
+                hidden_size=1024,  # 3072
+                intermediate_size=2048,  # 24576
+                num_hidden_layers=4,  # 28
+                num_attention_heads=4,  # 16
+                num_key_value_heads=4,  # 16
+                head_dim=256,
+                hidden_activation="gelu_pytorch_tanh",
+                max_position_embeddings=8192,
+                initializer_range=0.02,
+                rms_norm_eps=1e-06,
+                use_cache=True,
+                tie_word_embeddings=True,
+                rope_theta=10000.0,
+                attention_bias=False,
+                attention_dropout=0.0,
+            ).to_dict(),
+            image_token_index=5,  # NOTE: inside the reduced vocab size
+            boi_token_index=4,
+            eoi_token_index=6,
+            attn_implementation="eager",
+            vocab_size=32000,
+            projection_dim=1024,
+        ),
+    )
+

 if QWEN2_VL_AVAILABLE:
     MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig(
@@ -425,6 +489,26 @@ def create_processor(model_name: str):
         image_processor = SiglipImageProcessor(size={"height": 224, "width": 224}, image_seq_length=256)
         return PaliGemmaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer)

|
+    elif model_name.startswith("mini_gemma3"):
+        tokenizer_config = load_tokenizer_config(
+            os.path.join(
+                FAKE_CONFIGS_PATH,
+                "Google/Gemma3/gemma-3-4b-it/tokenizer_config.json",
+            )
+        )
+        tokenizer_base = train_bpe_tokenizer(
+            [
+                token.content
+                for key, token in sorted(
+                    tokenizer_config["added_tokens_decoder"].items(),
+                    key=lambda x: int(x[0]),
+                )
+            ]
+        )
+        fast_tokenizer = GemmaTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config)
+        image_processor = Gemma3ImageProcessor()
+        return Gemma3Processor(image_processor=image_processor, tokenizer=fast_tokenizer)
+
     else:
         raise ValueError(f"Processor not available for model {model_name}")

|
@@ -652,6 +736,25 @@ def run_mini_model_multimodal(
                 ),
             ],
         ),
+        pytest.param(
+            "mini_gemma3",
+            32,
+            1e-4,
+            torch.bfloat16,
+            1e-3,
+            1e-2,
+            1e-1,
+            1e-2,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not GEMMA3_AVAILABLE,
+                    reason="Gemma3 not available in this version of transformers",
+                ),
+            ],
+        ),
     ],
 )
 def test_mini_model_multimodal(
|
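As a quick local sanity check of the new `mini_gemma3` entry, the sketch below rebuilds the same shrunken Gemma3 config outside the test harness and instantiates the model. It is not part of the PR: it assumes `transformers>=4.50.0`, copies only the values visible in the diff above, and the parameter-count print is purely illustrative.

```python
# Not part of the PR: minimal sketch that rebuilds the mini_gemma3 config from the
# diff above and checks that it instantiates (assumes transformers>=4.50.0).
from transformers.models.gemma3.configuration_gemma3 import Gemma3Config, Gemma3TextConfig
from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
from transformers.models.siglip.configuration_siglip import SiglipVisionConfig

config = Gemma3Config(
    # 4-layer SigLIP tower, same sizes as the mini test config
    vision_config=SiglipVisionConfig(
        hidden_size=1152,
        intermediate_size=2048,
        num_hidden_layers=4,
        num_attention_heads=4,
        image_size=224,
        patch_size=14,
    ).to_dict(),
    # 4-layer text backbone with the reduced 32k vocab
    text_config=Gemma3TextConfig(
        vocab_size=32000,
        hidden_size=1024,
        intermediate_size=2048,
        num_hidden_layers=4,
        num_attention_heads=4,
        num_key_value_heads=4,
        head_dim=256,
        tie_word_embeddings=True,
    ).to_dict(),
    boi_token_index=4,
    image_token_index=5,
    eoi_token_index=6,
)

model = Gemma3ForConditionalGeneration(config)
print(f"mini_gemma3 parameter count: {sum(p.numel() for p in model.parameters()):,}")
```

Once the PR is applied, running `pytest -k mini_gemma3` against this test module should select just the added parametrization.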