|
15 | 15 | # Forked from flax/examples/gemma/sampler_test.py |
16 | 16 |
|
17 | 17 | import dataclasses |
| 18 | +from unittest import mock |
18 | 19 | from absl.testing import absltest |
19 | 20 | from absl.testing import parameterized |
20 | 21 | from flax import nnx |
21 | 22 | import jax |
| 23 | +import jax.numpy as jnp |
22 | 24 | import numpy as np |
23 | 25 | from tunix.generate import sampler as sampler_lib |
24 | 26 | from tunix.generate import utils |
| 27 | +from tunix.models.gemma4 import model as gemma4_model_lib |
25 | 28 | from tunix.tests import test_common as tc |
26 | 29 |
|
27 | | - |
28 | 30 | @dataclasses.dataclass(kw_only=True, frozen=True) |
29 | 31 | class ModelConfigWithDtype(tc.ModelConfig): |
30 | 32 | dtype: jax.numpy.dtype = jax.numpy.bfloat16 |
@@ -605,6 +607,56 @@ def test_forbidden_token_ids(self): |
605 | 607 | self.assertLen(result.tokens[0], max_generation_steps) |
606 | 608 | self.assertNoCommonElements(result.tokens[0], forbidden_tokens) |
607 | 609 |
|
def test_gemma4_smoke_test(self):
  """Smoke-tests the sampler end to end on a tiny Gemma4 model.

  Builds a deliberately small Gemma4 configuration, wraps the resulting
  model in a `Sampler` with a mocked tokenizer, and issues one batched
  sampling call. Nothing about the output is checked: the test passes as
  long as no exception is raised, which surfaces JAX compilation and model
  implementation errors early.
  """
  attention_types = gemma4_model_lib.AttentionType
  # NOTE(review): the pattern lists four entries while num_layers is 2;
  # presumably only part of the pattern is consumed -- confirm.
  layer_pattern = (attention_types.LOCAL_SLIDING,) * 3 + (
      attention_types.GLOBAL,
  )
  tiny_config = gemma4_model_lib.ModelConfig(
      num_layers=2,
      num_embed=32,
      embed_dim=16,
      hidden_dim=16,
      num_heads=4,
      head_dim=16,
      num_kv_heads=1,
      per_layer_input_dim=16,
      sliding_window_size=4,
      param_dtype=jnp.bfloat16,
      attention_pattern=layer_pattern,
      final_logit_softcap=30.0,
      local_rope_proportion=1.0,
      global_rope_proportion=0.25,
      global_key_size=16,
      k_eq_v_global=False,
      local_base_frequency=10000,
      global_base_frequency=1000000,
      local_scale_factor=1.0,
      global_scale_factor=1.0,
  )
  gemma4 = gemma4_model_lib.Gemma4(tiny_config, rngs=nnx.Rngs(0))

  # Cache dimensions are taken from the model config so the two stay in sync.
  kv_cache_config = sampler_lib.CacheConfig(
      cache_size=32,
      num_layers=tiny_config.num_layers,
      num_kv_heads=tiny_config.num_kv_heads,
      head_dim=tiny_config.head_dim,
  )

  # Decoding is stubbed out: every decode call yields the same fixed string.
  tokenizer = tc.MockVocab()
  tokenizer.DecodeIds = mock.MagicMock(return_value='decoded_string')

  prompts = ['input string', 'hello world']
  run_sampler = sampler_lib.Sampler(gemma4, tokenizer, kv_cache_config)
  run_sampler(prompts, max_generation_steps=10, max_prompt_length=10)
| 659 | + |
608 | 660 |
|
609 | 661 | if __name__ == '__main__': |
610 | 662 | absltest.main() |
0 commit comments