|
31 | 31 | from tunix.generate import mappings |
32 | 32 | from tunix.generate import sampler as vanilla_sampler |
33 | 33 | from tunix.generate import vllm_sampler |
| 34 | +from tunix.models.dummy_model_creator import create_dummy_model |
34 | 35 | from tunix.models.llama3 import model as llama_lib |
35 | 36 | from tunix.models.llama3 import params as llama_params |
36 | 37 | from tunix.sft import utils as base_utils |
@@ -357,6 +358,169 @@ async def dispatch_requests(): |
357 | 358 | ), |
358 | 359 | ) |
359 | 360 |
|
| 361 | + def test_vllm_sampler_sampling_kwargs(self): |
| 362 | + """Test that sampling kwargs are correctly applied to sampling_params.""" |
| 363 | + tunix_model = create_dummy_model( |
| 364 | + model_class=llama_lib.Llama3, |
| 365 | + config=llama_lib.ModelConfig.llama3p2_1b(), |
| 366 | + mesh=self.mesh, |
| 367 | + random_seed=3, |
| 368 | + ) |
| 369 | + |
| 370 | + model_tokenizer = transformers.AutoTokenizer.from_pretrained( |
| 371 | + self.model_path |
| 372 | + ) |
| 373 | + |
| 374 | + prompts = ["Hello, my name is Tom."] |
| 375 | + inputs = tc.batch_templatize(prompts, model_tokenizer) |
| 376 | + |
| 377 | + mapping_config = mappings.MappingConfig.build(tunix_model) |
| 378 | + |
| 379 | + # Test 1: Config sampling_kwargs are applied |
| 380 | + config_sampling_kwargs = { |
| 381 | + "frequency_penalty": 0.5, |
| 382 | + "presence_penalty": 0.3, |
| 383 | + } |
| 384 | + |
| 385 | + vllm_config = vllm_sampler.VllmConfig( |
| 386 | + mesh=self.mesh, |
| 387 | + hbm_utilization=0.2, |
| 388 | + init_with_random_weights=True, |
| 389 | + tpu_backend_type="jax", |
| 390 | + mapping_config=mapping_config, |
| 391 | + server_mode=False, |
| 392 | + sampling_kwargs=config_sampling_kwargs, |
| 393 | + engine_kwargs={ |
| 394 | + "model": self.model_path, |
| 395 | + "max_model_len": 512, |
| 396 | + "enable_prefix_caching": True, |
| 397 | + }, |
| 398 | + ) |
| 399 | + |
| 400 | + vl_sampler = vllm_sampler.VllmSampler( |
| 401 | + tokenizer=model_tokenizer, |
| 402 | + config=vllm_config, |
| 403 | + ) |
| 404 | + |
| 405 | + state = nnx.state(tunix_model) |
| 406 | + vl_sampler.load_checkpoint(state) |
| 407 | + |
| 408 | + # Mock the generate method to capture sampling_params |
| 409 | + original_generate = vl_sampler.llm.generate |
| 410 | + captured_sampling_params = [] |
| 411 | + |
| 412 | + def mock_generate(prompts, sampling_params, **kwargs): |
| 413 | + captured_sampling_params.append(sampling_params) |
| 414 | + return original_generate(prompts, sampling_params, **kwargs) |
| 415 | + |
| 416 | + vl_sampler.llm.generate = mock_generate |
| 417 | + |
| 418 | + # Call with additional method kwargs |
| 419 | + method_sampling_kwargs = {"min_tokens": 10} |
| 420 | + vl_sampler( |
| 421 | + input_strings=inputs, |
| 422 | + max_generation_steps=128, |
| 423 | + max_prompt_length=None, |
| 424 | + temperature=0.0, |
| 425 | + top_k=1, |
| 426 | + seed=0, |
| 427 | + echo=False, |
| 428 | + pad_output=True, |
| 429 | + **method_sampling_kwargs, |
| 430 | + ) |
| 431 | + |
| 432 | + # Verify that both config and method kwargs were applied |
| 433 | + self.assertLen(captured_sampling_params, 1) |
| 434 | + sampling_params = captured_sampling_params[0] |
| 435 | + |
| 436 | + # Check config kwargs |
| 437 | + self.assertEqual(sampling_params.frequency_penalty, 0.5) |
| 438 | + self.assertEqual(sampling_params.presence_penalty, 0.3) |
| 439 | + |
| 440 | + # Check method kwargs |
| 441 | + self.assertEqual(sampling_params.min_tokens, 10) |
| 442 | + |
| 443 | + def test_vllm_sampler_sampling_kwargs_override(self): |
| 444 | + """Test that method kwargs override config sampling_kwargs.""" |
| 445 | + tunix_model = create_dummy_model( |
| 446 | + model_class=llama_lib.Llama3, |
| 447 | + config=llama_lib.ModelConfig.llama3p2_1b(), |
| 448 | + mesh=self.mesh, |
| 449 | + random_seed=3, |
| 450 | + ) |
| 451 | + |
| 452 | + model_tokenizer = transformers.AutoTokenizer.from_pretrained( |
| 453 | + self.model_path |
| 454 | + ) |
| 455 | + |
| 456 | + prompts = ["Hello, my name is Tom."] |
| 457 | + inputs = tc.batch_templatize(prompts, model_tokenizer) |
| 458 | + |
| 459 | + mapping_config = mappings.MappingConfig.build(tunix_model) |
| 460 | + |
| 461 | + # Config has frequency_penalty = 0.5 |
| 462 | + config_sampling_kwargs = { |
| 463 | + "frequency_penalty": 0.5, |
| 464 | + "presence_penalty": 0.3, |
| 465 | + } |
| 466 | + |
| 467 | + vllm_config = vllm_sampler.VllmConfig( |
| 468 | + mesh=self.mesh, |
| 469 | + hbm_utilization=0.2, |
| 470 | + init_with_random_weights=True, |
| 471 | + tpu_backend_type="jax", |
| 472 | + mapping_config=mapping_config, |
| 473 | + server_mode=False, |
| 474 | + sampling_kwargs=config_sampling_kwargs, |
| 475 | + engine_kwargs={ |
| 476 | + "model": self.model_path, |
| 477 | + "max_model_len": 512, |
| 478 | + "enable_prefix_caching": True, |
| 479 | + }, |
| 480 | + ) |
| 481 | + |
| 482 | + vl_sampler = vllm_sampler.VllmSampler( |
| 483 | + tokenizer=model_tokenizer, |
| 484 | + config=vllm_config, |
| 485 | + ) |
| 486 | + |
| 487 | + state = nnx.state(tunix_model) |
| 488 | + vl_sampler.load_checkpoint(state) |
| 489 | + |
| 490 | + # Mock the generate method to capture sampling_params |
| 491 | + original_generate = vl_sampler.llm.generate |
| 492 | + captured_sampling_params = [] |
| 493 | + |
| 494 | + def mock_generate(prompts, sampling_params, **kwargs): |
| 495 | + captured_sampling_params.append(sampling_params) |
| 496 | + return original_generate(prompts, sampling_params, **kwargs) |
| 497 | + |
| 498 | + vl_sampler.llm.generate = mock_generate |
| 499 | + |
| 500 | + # Call with method kwargs that override config kwargs |
| 501 | + method_sampling_kwargs = {"frequency_penalty": 0.8} # Override from 0.5 to 0.8 |
| 502 | + vl_sampler( |
| 503 | + input_strings=inputs, |
| 504 | + max_generation_steps=128, |
| 505 | + max_prompt_length=None, |
| 506 | + temperature=0.0, |
| 507 | + top_k=1, |
| 508 | + seed=0, |
| 509 | + echo=False, |
| 510 | + pad_output=True, |
| 511 | + **method_sampling_kwargs, |
| 512 | + ) |
| 513 | + |
| 514 | + # Verify that method kwargs override config kwargs |
| 515 | + self.assertLen(captured_sampling_params, 1) |
| 516 | + sampling_params = captured_sampling_params[0] |
| 517 | + |
| 518 | + # Check that method kwarg overrides config kwarg |
| 519 | + self.assertEqual(sampling_params.frequency_penalty, 0.8) |
| 520 | + |
| 521 | + # Check that other config kwargs are still applied |
| 522 | + self.assertEqual(sampling_params.presence_penalty, 0.3) |
| 523 | + |
360 | 524 |
|
361 | 525 | class VllmSamplerConfigTest(absltest.TestCase): |
362 | 526 | """Unit tests for VllmSampler config plumbing (no hardware required).""" |
|
0 commit comments