Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 51 additions & 1 deletion tests/generate/sampler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,18 @@
# Forked from flax/examples/gemma/sampler_test.py

import dataclasses
from unittest import mock
from absl.testing import absltest
from absl.testing import parameterized
from flax import nnx
import jax
import jax.numpy as jnp
import numpy as np
from tunix.generate import sampler as sampler_lib
from tunix.generate import utils
from tunix.models.gemma4 import model as gemma4_model_lib
from tunix.tests import test_common as tc


@dataclasses.dataclass(kw_only=True, frozen=True)
class ModelConfigWithDtype(tc.ModelConfig):
dtype: jax.numpy.dtype = jax.numpy.bfloat16
Expand Down Expand Up @@ -605,6 +607,54 @@ def test_forbidden_token_ids(self):
self.assertLen(result.tokens[0], max_generation_steps)
self.assertNoCommonElements(result.tokens[0], forbidden_tokens)

def test_gemma4_smoke_test(self):
"""Runs a sampling call with a dummy Gemma4 config.

Useful to catch JAX compilation and model implementation errors early.
"""
config = gemma4_model_lib.ModelConfig(
num_layers=2,
num_embed=32,
embed_dim=16,
hidden_dim=16,
num_heads=4,
head_dim=16,
num_kv_heads=1,
per_layer_input_dim=16,
sliding_window_size=4,
param_dtype=jnp.bfloat16,
attention_pattern=(
gemma4_model_lib.AttentionType.GLOBAL,
gemma4_model_lib.AttentionType.LOCAL_SLIDING,
),
final_logit_softcap=30.0,
local_rope_proportion=1.0,
global_rope_proportion=0.25,
global_key_size=16,
k_eq_v_global=False,
local_base_frequency=10000,
global_base_frequency=1000000,
local_scale_factor=1.0,
global_scale_factor=1.0,
)
rngs = nnx.Rngs(0)
model = gemma4_model_lib.Gemma4(config, rngs=rngs)
cache_config = sampler_lib.CacheConfig(
cache_size=32,
num_layers=config.num_layers,
num_kv_heads=config.num_kv_heads,
head_dim=config.head_dim,
)
mock_tokenizer = tc.MockVocab()
mock_tokenizer.DecodeIds = mock.MagicMock()
mock_tokenizer.DecodeIds.return_value = 'decoded_string'
sampler = sampler_lib.Sampler(model, mock_tokenizer, cache_config)
sampler(
['input string', 'hello world'],
max_generation_steps=10,
max_prompt_length=10,
)


if __name__ == '__main__':
absltest.main()
1 change: 1 addition & 0 deletions tests/models/gemma4/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ def test_forward_pass_dense(self):
config.num_heads = 4
config.head_dim = 64
config.num_kv_heads = 1
config.frac_shared_layers = 0.0

rngs = nnx.Rngs(0)
model = model_lib.Gemma4(config, rngs=rngs)
Expand Down
Loading
Loading