
Commit 092e425

Author: eljandoubi
Commit message: add bf16 test
Parent: 28566dc

4 files changed: 308 additions, 87 deletions

File: test/convergence/bf16/test_mini_models.py (59 additions, 0 deletions)
@@ -20,6 +20,7 @@
 
 from liger_kernel.transformers import apply_liger_kernel_to_gemma
 from liger_kernel.transformers import apply_liger_kernel_to_gemma2
+from liger_kernel.transformers import apply_liger_kernel_to_gemma3
 from liger_kernel.transformers import apply_liger_kernel_to_granite
 from liger_kernel.transformers import apply_liger_kernel_to_llama
 from liger_kernel.transformers import apply_liger_kernel_to_mistral
@@ -35,6 +36,7 @@
 from test.utils import assert_verbose_allclose
 from test.utils import revert_liger_kernel_to_gemma
 from test.utils import revert_liger_kernel_to_gemma2
+from test.utils import revert_liger_kernel_to_gemma3
 from test.utils import revert_liger_kernel_to_granite
 from test.utils import revert_liger_kernel_to_llama
 from test.utils import revert_liger_kernel_to_mistral
@@ -93,6 +95,14 @@
 except ImportError:
     OLMO2_AVAILABLE = False
 
+try:
+    from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+
+    GEMMA3_AVAILABLE = True
+except ImportError:
+    GEMMA3_AVAILABLE = False
+
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -326,6 +336,36 @@
     ),
 }
 
+if GEMMA3_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_gemma3"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_gemma3,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3,
+        model_class=Gemma3ForCausalLM,
+        mini_model_config=Gemma3TextConfig(
+            vocab_size=32000,  # 262144
+            hidden_size=1024,  # 1152
+            intermediate_size=2048,  # 6912
+            num_hidden_layers=4,  # 26
+            num_attention_heads=4,
+            num_key_value_heads=1,
+            head_dim=256,
+            hidden_activation="gelu_pytorch_tanh",
+            max_position_embeddings=8192,  # 32768
+            initializer_range=0.02,
+            rms_norm_eps=1e-06,
+            use_cache=True,
+            pad_token_id=0,
+            bos_token_id=2,
+            eos_token_id=1,
+            tie_word_embeddings=True,
+            rope_theta=10000.0,  # 1000000
+            attention_bias=False,
+            attention_dropout=0.0,
+            attn_implementation="eager",
+        ),
+    )
+
+
 if MLLAMA_AVAILABLE:
     MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_mllama,
@@ -816,6 +856,25 @@ def run_mini_model(
         #         not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
         #     ),
         # ),
+        pytest.param(
+            "mini_gemma3",
+            32,
+            1e-4,
+            torch.bfloat16,
+            1e-3,
+            1e-2,
+            1e-1,
+            1e-2,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not GEMMA3_AVAILABLE,
+                    reason="Gemma3 not available in this version of transformers",
+                ),
+            ],
+        ),
     ],
 )
 def test_mini_model(
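
Note (not part of the diff): the new pytest.param is pure data. run_mini_model looks up MINI_MODEL_SETUPS["mini_gemma3"], applies or reverts the Liger patch, builds the tiny model, and compares the patched and unpatched runs within the bf16 tolerances listed in the param (their exact ordering follows test_mini_model's signature, which this commit does not change). A minimal sketch of the instantiation step, assuming it runs inside this test module where MINI_MODEL_SETUPS and device are defined:

import torch

# Hypothetical sketch; the repository's run_mini_model also handles data loading,
# the optimizer loop, and the assert_verbose_allclose comparisons.
setup = MINI_MODEL_SETUPS["mini_gemma3"]             # entry registered in the hunk above
setup.liger_kernel_patch_func()                      # apply_liger_kernel_to_gemma3()
model = setup.model_class(setup.mini_model_config)   # tiny Gemma3ForCausalLM (4 layers, hidden 1024)
model = model.to(torch.bfloat16).to(device)          # device comes from infer_device()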

File: test/convergence/bf16/test_mini_models_multimodal.py (103 additions, 0 deletions)
@@ -7,7 +7,10 @@
 from datasets import load_dataset
 from torch.utils.data import DataLoader
 from transformers import PreTrainedTokenizerFast
+from transformers.models.gemma.tokenization_gemma_fast import GemmaTokenizerFast
+from transformers.models.siglip.configuration_siglip import SiglipVisionConfig
 
+from liger_kernel.transformers import apply_liger_kernel_to_gemma3
 from liger_kernel.transformers import apply_liger_kernel_to_mllama
 from liger_kernel.transformers import apply_liger_kernel_to_paligemma
 from liger_kernel.transformers import apply_liger_kernel_to_qwen2_5_vl
@@ -18,6 +21,7 @@
 from test.utils import assert_verbose_allclose
 from test.utils import load_tokenizer_config
 from test.utils import multimodal_collate_fn
+from test.utils import revert_liger_kernel_to_gemma3
 from test.utils import revert_liger_kernel_to_mllama
 from test.utils import revert_liger_kernel_to_Paligemma
 from test.utils import revert_liger_kernel_to_qwen2_5_vl
@@ -80,6 +84,18 @@
 except ImportError:
     PALIGEMMA_AVAILABLE = False
 
+try:
+    # Gemma3 is only available in transformers>=4.50.0
+    from transformers.models.gemma3.configuration_gemma3 import Gemma3Config
+    from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
+    from transformers.models.gemma3.image_processing_gemma3 import Gemma3ImageProcessor
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForConditionalGeneration
+    from transformers.models.gemma3.processing_gemma3 import Gemma3Processor
+
+    GEMMA3_AVAILABLE = True
+except ImportError:
+    GEMMA3_AVAILABLE = False
+
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -254,6 +270,54 @@
         ),
     )
 
+if GEMMA3_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_gemma3"] = MiniModelConfig(
+        liger_kernel_patch_func=functools.partial(apply_liger_kernel_to_gemma3, fused_linear_cross_entropy=False),
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3,
+        model_class=Gemma3ForConditionalGeneration,
+        mini_model_config=Gemma3Config(
+            vision_config=SiglipVisionConfig(
+                attention_dropout=0.0,
+                hidden_act="gelu_pytorch_tanh",
+                hidden_size=1152,
+                image_size=224,
+                intermediate_size=2048,  # 4304
+                layer_norm_eps=1e-06,
+                num_attention_heads=4,  # 16
+                num_channels=3,
+                num_hidden_layers=4,  # 27
+                num_image_tokens=256,
+                num_positions=256,
+                patch_size=14,
+                projection_dim=1024,  # 2304
+            ).to_dict(),
+            text_config=Gemma3TextConfig(
+                vocab_size=32000,  # 256000
+                hidden_size=1024,  # 3072
+                intermediate_size=2048,  # 24576
+                num_hidden_layers=4,  # 28
+                num_attention_heads=4,  # 16
+                num_key_value_heads=4,  # 16
+                head_dim=256,
+                hidden_activation="gelu_pytorch_tanh",
+                max_position_embeddings=8192,
+                initializer_range=0.02,
+                rms_norm_eps=1e-06,
+                use_cache=True,
+                tie_word_embeddings=True,
+                rope_theta=10000.0,
+                attention_bias=False,
+                attention_dropout=0.0,
+            ).to_dict(),
+            image_token_index=5,  # NOTE: outside the vocab size
+            boi_token_index=4,
+            eoi_token_index=6,
+            attn_implementation="eager",
+            vocab_size=32000,
+            projection_dim=1024,
+        ),
+    )
+
 
 if QWEN2_VL_AVAILABLE:
     MINI_MODEL_SETUPS["mini_qwen2_vl"] = MiniModelConfig(
@@ -425,6 +489,26 @@ def create_processor(model_name: str):
         image_processor = SiglipImageProcessor(size={"height": 224, "width": 224}, image_seq_length=256)
         return PaliGemmaProcessor(image_processor=image_processor, tokenizer=fast_tokenizer)
 
+    elif model_name.startswith("mini_gemma3"):
+        tokenizer_config = load_tokenizer_config(
+            os.path.join(
+                FAKE_CONFIGS_PATH,
+                "Google/Gemma3/gemma-3-4b-it/tokenizer_config.json",
+            )
+        )
+        tokenizer_base = train_bpe_tokenizer(
+            [
+                token.content
+                for key, token in sorted(
+                    tokenizer_config["added_tokens_decoder"].items(),
+                    key=lambda x: int(x[0]),
+                )
+            ]
+        )
+        fast_tokenizer = GemmaTokenizerFast(tokenizer_object=tokenizer_base, **tokenizer_config)
+        image_processor = Gemma3ImageProcessor()
+        return Gemma3Processor(image_processor=image_processor, tokenizer=fast_tokenizer)
+
     else:
         raise ValueError(f"Processor not available for model {model_name}")
 
@@ -652,6 +736,25 @@ def run_mini_model_multimodal(
                 ),
             ],
         ),
+        pytest.param(
+            "mini_gemma3",
+            32,
+            1e-4,
+            torch.bfloat16,
+            1e-3,
+            1e-2,
+            1e-1,
+            1e-2,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not GEMMA3_AVAILABLE,
+                    reason="Gemma3 not available in this version of transformers",
+                ),
+            ],
+        ),
     ],
 )
 def test_mini_model_multimodal(
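
Note (not part of the diff): the multimodal setup pairs a down-sized SigLIP vision tower with the tiny Gemma3 text config, and builds a Gemma3Processor from a BPE tokenizer trained on the fake tokenizer_config plus a stock Gemma3ImageProcessor. A rough sketch of how those pieces fit together at test time, assuming it runs inside this test module where MINI_MODEL_SETUPS, create_processor, and device are defined (the training loop in run_mini_model_multimodal is not reproduced):

import torch

# Hypothetical illustration of the moving parts added in this diff.
setup = MINI_MODEL_SETUPS["mini_gemma3"]
setup.liger_kernel_patch_func()                      # partial(apply_liger_kernel_to_gemma3, fused_linear_cross_entropy=False)
processor = create_processor("mini_gemma3")          # Gemma3Processor(Gemma3ImageProcessor(), GemmaTokenizerFast(...))
model = setup.model_class(setup.mini_model_config)   # tiny Gemma3ForConditionalGeneration
model = model.to(torch.bfloat16).to(device)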

File: test/convergence/bf16/test_mini_models_with_logits.py (59 additions, 0 deletions)
@@ -20,6 +20,7 @@
 
 from liger_kernel.transformers import apply_liger_kernel_to_gemma
 from liger_kernel.transformers import apply_liger_kernel_to_gemma2
+from liger_kernel.transformers import apply_liger_kernel_to_gemma3
 from liger_kernel.transformers import apply_liger_kernel_to_granite
 from liger_kernel.transformers import apply_liger_kernel_to_llama
 from liger_kernel.transformers import apply_liger_kernel_to_mistral
@@ -35,6 +36,7 @@
 from test.utils import assert_verbose_allclose
 from test.utils import revert_liger_kernel_to_gemma
 from test.utils import revert_liger_kernel_to_gemma2
+from test.utils import revert_liger_kernel_to_gemma3
 from test.utils import revert_liger_kernel_to_granite
 from test.utils import revert_liger_kernel_to_llama
 from test.utils import revert_liger_kernel_to_mistral
@@ -93,6 +95,14 @@
 except ImportError:
     OLMO2_AVAILABLE = False
 
+try:
+    from transformers.models.gemma3.configuration_gemma3 import Gemma3TextConfig
+    from transformers.models.gemma3.modeling_gemma3 import Gemma3ForCausalLM
+
+    GEMMA3_AVAILABLE = True
+except ImportError:
+    GEMMA3_AVAILABLE = False
+
 from liger_kernel.utils import infer_device
 
 device = infer_device()
@@ -326,6 +336,36 @@
     ),
 }
 
+if GEMMA3_AVAILABLE:
+    MINI_MODEL_SETUPS["mini_gemma3"] = MiniModelConfig(
+        liger_kernel_patch_func=apply_liger_kernel_to_gemma3,
+        liger_kernel_patch_revert_func=revert_liger_kernel_to_gemma3,
+        model_class=Gemma3ForCausalLM,
+        mini_model_config=Gemma3TextConfig(
+            vocab_size=32000,  # 262144
+            hidden_size=1024,  # 1152
+            intermediate_size=2048,  # 6912
+            num_hidden_layers=4,  # 26
+            num_attention_heads=4,
+            num_key_value_heads=1,
+            head_dim=256,
+            hidden_activation="gelu_pytorch_tanh",
+            max_position_embeddings=8192,  # 32768
+            initializer_range=0.02,
+            rms_norm_eps=1e-06,
+            use_cache=True,
+            pad_token_id=0,
+            bos_token_id=2,
+            eos_token_id=1,
+            tie_word_embeddings=True,
+            rope_theta=10000.0,  # 1000000
+            attention_bias=False,
+            attention_dropout=0.0,
+            attn_implementation="eager",
+        ),
+    )
+
+
 if MLLAMA_AVAILABLE:
     MINI_MODEL_SETUPS["mini_mllama"] = MiniModelConfig(
         liger_kernel_patch_func=apply_liger_kernel_to_mllama,
@@ -815,6 +855,25 @@ def run_mini_model(
         #         not supports_bfloat16(), reason="bfloat16 not supported on this GPU"
         #     ),
         # ),
+        pytest.param(
+            "mini_gemma3",
+            32,
+            1e-4,
+            torch.bfloat16,
+            1e-3,
+            1e-2,
+            1e-1,
+            1e-2,
+            1e-2,
+            1e-2,
+            marks=[
+                pytest.mark.skipif(not supports_bfloat16(), reason="bfloat16 not supported on this GPU"),
+                pytest.mark.skipif(
+                    not GEMMA3_AVAILABLE,
+                    reason="Gemma3 not available in this version of transformers",
+                ),
+            ],
+        ),
     ],
 )
 def test_mini_model(
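
Note (not part of the diff): test_mini_models_with_logits.py registers the identical mini_gemma3 text-model setup; the file name suggests this variant compares runs with fully materialized logits rather than the fused linear-cross-entropy path. Either way, the bf16 comparisons reduce to elementwise closeness checks with the loose tolerances from the pytest.param; a torch-only sketch of such a check (assert_verbose_allclose itself lives in test/utils.py and is unchanged by this commit):

import torch

# Illustrative bf16-style closeness check; rtol/atol mirror the 1e-3 / 1e-2
# loss tolerances in the pytest.param above. bf16 keeps only ~8 bits of mantissa,
# which is why these tolerances are much looser than their fp32 counterparts.
loss_liger = torch.tensor(2.3145, dtype=torch.bfloat16)
loss_ref = torch.tensor(2.3125, dtype=torch.bfloat16)
torch.testing.assert_close(loss_liger.float(), loss_ref.float(), rtol=1e-3, atol=1e-2)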
