|
34 | 34 | from tunix.sft import hooks |
35 | 35 | from tunix.sft import peft_trainer |
36 | 36 | from tunix.sft import profiler |
| 37 | +from tunix.sft import utils |
37 | 38 | from tunix.tests import test_common as tc |
38 | 39 | from tunix.utils import compat |
39 | 40 |
|
@@ -634,7 +635,67 @@ def _post_process_eval_step(self, aux): |
634 | 635 | self.assertEqual(train_invoke, {'foo': 2, 'bar': 4}) |
635 | 636 | self.assertEqual(eval_invoke, {'foo': 1, 'bar': 16}) |
636 | 637 |
|
def test_loss_output_format(self):
  """Checks that aux metrics of a `utils.LossOutput` loss reach the step hooks.

  A constant loss function returns a `utils.LossOutput` carrying two auxiliary
  `utils.WeightedMetric` entries ('foo' and 'bar'). A trainer subclass sums
  whatever its train/eval post-processing hooks receive under those keys, and
  the sums are compared against fixed expected totals.
  """

  def _f32(value):
    # Scalar float32 array; every metric component below is built this way.
    return jnp.array(value, dtype=jnp.float32)

  def custom_loss_fn(
      model: nnx.Module,
      input_tokens: jax.Array,
      input_mask: jax.Array,
      positions: jax.Array,
      attention_mask: jax.Array,
      images: jax.Array | None = None,
  ) -> utils.LossOutput:
    # The loss is constant: none of the inputs influence the result.
    del model, input_tokens, input_mask, positions, attention_mask, images
    primary = utils.WeightedMetric(_f32(2.0), _f32(2.0))
    aux_metrics = {
        'foo': utils.WeightedMetric(_f32(10.0), _f32(5.0)),
        'bar': utils.WeightedMetric(_f32(6.0), _f32(2.0)),
    }
    return utils.LossOutput(primary_loss=primary, aux_metrics=aux_metrics)

  # Running sums of the aux values handed to each hook.
  train_totals = {'foo': 0.0, 'bar': 0.0}
  eval_totals = {'foo': 0.0, 'bar': 0.0}

  class AccumulatingTrainer(peft_trainer.PeftTrainer):
    """Trainer whose post-processing hooks add aux values to the sums."""

    def _post_process_train_step(self, aux):
      for key in ('foo', 'bar'):
        train_totals[key] += aux[key]

    def _post_process_eval_step(self, aux):
      for key in ('foo', 'bar'):
        eval_totals[key] += aux[key]

  config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100)
  model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0))

  # `with_loss_fn` is called without `has_aux`; the aux metrics travel inside
  # the returned `LossOutput` itself.
  trainer = AccumulatingTrainer(model, optax.sgd(1e-3), config)
  trainer = trainer.with_gen_model_input_fn(dummy_gen_model_input_fn)
  trainer = trainer.with_loss_fn(custom_loss_fn)

  trainer.train(self.train_ds, self.eval_ds)

  # Expected per-step hook values are the first WeightedMetric argument
  # divided by the second: foo = 10.0 / 5.0 = 2.0, bar = 6.0 / 2.0 = 3.0.
  # The train totals correspond to 2 training steps.
  expected_train = {'foo': 4.0, 'bar': 6.0}
  self.assertEqual(train_totals, expected_train)

  # The eval totals correspond to 4 eval-hook calls (4 * 2.0 and 4 * 3.0).
  # NOTE(review): presumably several passes over a 2-batch eval_ds; the eval
  # schedule is decided by the trainer and fixtures -- confirm there.
  expected_eval = {'foo': 8.0, 'bar': 12.0}
  self.assertEqual(eval_totals, expected_eval)
637 | 697 | def test_injected_params(self): |
| 698 | + |
638 | 699 | config = peft_trainer.TrainingConfig(eval_every_n_steps=2, max_steps=100) |
639 | 700 | model = tc.ToyTransformer(config=tc.ModelConfig(), rngs=nnx.Rngs(0)) |
640 | 701 |
|
|
0 commit comments