Skip to content

Commit 0614afe

Browse files
author
The tunix Authors
committed
Gemma4 sampling optimizations.
PiperOrigin-RevId: 901155118
1 parent 0fac961 commit 0614afe

3 files changed

Lines changed: 197 additions & 40 deletions

File tree

tests/generate/sampler_test.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,18 @@
1515
# Forked from flax/examples/gemma/sampler_test.py
1616

1717
import dataclasses
18+
from unittest import mock
1819
from absl.testing import absltest
1920
from absl.testing import parameterized
2021
from flax import nnx
2122
import jax
23+
import jax.numpy as jnp
2224
import numpy as np
2325
from tunix.generate import sampler as sampler_lib
2426
from tunix.generate import utils
27+
from tunix.models.gemma4 import model as gemma4_model_lib
2528
from tunix.tests import test_common as tc
2629

27-
2830
@dataclasses.dataclass(kw_only=True, frozen=True)
2931
class ModelConfigWithDtype(tc.ModelConfig):
3032
dtype: jax.numpy.dtype = jax.numpy.bfloat16
@@ -605,6 +607,56 @@ def test_forbidden_token_ids(self):
605607
self.assertLen(result.tokens[0], max_generation_steps)
606608
self.assertNoCommonElements(result.tokens[0], forbidden_tokens)
607609

610+
def test_gemma4_smoke_test(self):
  """End-to-end smoke test: sample from a tiny Gemma4 model.

  Builds a minimal Gemma4 configuration, wraps it in a Sampler with a mocked
  tokenizer, and runs one generation call. This surfaces JAX compilation and
  model-implementation errors early, without asserting on output values.
  """
  # Three local-sliding layers followed by one global-attention layer.
  attn_types = gemma4_model_lib.AttentionType
  pattern = (attn_types.LOCAL_SLIDING,) * 3 + (attn_types.GLOBAL,)

  # Deliberately tiny dimensions so the test compiles and runs fast.
  tiny_config = gemma4_model_lib.ModelConfig(
      num_layers=2,
      num_embed=32,
      embed_dim=16,
      hidden_dim=16,
      num_heads=4,
      head_dim=16,
      num_kv_heads=1,
      per_layer_input_dim=16,
      sliding_window_size=4,
      param_dtype=jnp.bfloat16,
      attention_pattern=pattern,
      final_logit_softcap=30.0,
      local_rope_proportion=1.0,
      global_rope_proportion=0.25,
      global_key_size=16,
      k_eq_v_global=False,
      local_base_frequency=10000,
      global_base_frequency=1000000,
      local_scale_factor=1.0,
      global_scale_factor=1.0,
  )
  gemma_model = gemma4_model_lib.Gemma4(tiny_config, rngs=nnx.Rngs(0))

  # KV-cache sized to match the tiny model above.
  kv_cache_config = sampler_lib.CacheConfig(
      cache_size=32,
      num_layers=tiny_config.num_layers,
      num_kv_heads=tiny_config.num_kv_heads,
      head_dim=tiny_config.head_dim,
  )

  # The tokenizer's decode step is mocked; only the sampling path matters.
  fake_tokenizer = tc.MockVocab()
  fake_tokenizer.DecodeIds = mock.MagicMock(return_value='decoded_string')

  smoke_sampler = sampler_lib.Sampler(
      gemma_model, fake_tokenizer, kv_cache_config
  )
  smoke_sampler(
      ['input string', 'hello world'],
      max_generation_steps=10,
      max_prompt_length=10,
  )
608660

609661
if __name__ == '__main__':
610662
absltest.main()

tests/models/gemma4/model_test.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ def test_forward_pass_dense(self):
3333
config.num_heads = 4
3434
config.head_dim = 64
3535
config.num_kv_heads = 1
36+
config.frac_shared_layers = 0.0
3637

3738
rngs = nnx.Rngs(0)
3839
model = model_lib.Gemma4(config, rngs=rngs)

0 commit comments

Comments
 (0)