|
28 | 28 | from torch.nn.utils.clip_grad import clip_grad_norm_ |
29 | 29 | from tqdm import tqdm |
30 | 30 | import torch.cuda.amp as amp |
| 31 | +import scipy.sparse as sp |
31 | 32 |
|
32 | 33 | from recbole.data.interaction import Interaction |
33 | 34 | from recbole.data.dataloader import FullSortEvalDataLoader |
@@ -92,7 +93,6 @@ def sync_grad_loss(self): |
92 | 93 | sync_loss += torch.sum(params) * 0 |
93 | 94 | return sync_loss |
94 | 95 |
|
95 | | - |
96 | 96 | class Trainer(AbstractTrainer): |
97 | 97 | r"""The basic Trainer for basic training and evaluation strategies in recommender systems. This class defines common |
98 | 98 | functions for training and evaluation processes of most recommender system models, including fit(), evaluate(), |
@@ -671,6 +671,127 @@ def _spilt_predict(self, interaction, batch_size): |
671 | 671 | result_list.append(result) |
672 | 672 | return torch.cat(result_list, dim=0) |
673 | 673 |
|
class ALSTrainer(Trainer):
    r"""ALSTrainer is designed for the ALS model of the implicit library: https://benfred.github.io/implicit

    Unlike the gradient-based :class:`Trainer`, ALS refits its latent factors from
    the *entire* interaction matrix once per epoch, so a "train epoch" here is a
    single ``calculate_loss`` call on the full dataset in sparse CSR form rather
    than a loop over mini-batches.
    """

    def __init__(self, config, model):
        # Zero-argument super() is the Python 3 idiom; behaviorally identical to
        # super(ALSTrainer, self).__init__(config, model).
        super().__init__(config, model)

    def fit(
        self,
        train_data,
        valid_data=None,
        verbose=True,
        saved=True,
        show_progress=False,
        callback_fn=None,
    ):
        r"""Train the model based on the train data and the valid data.

        Args:
            train_data (DataLoader): the train data
            valid_data (DataLoader, optional): the valid data, default: None.
                                               If it's None, the early_stopping is invalid.
            verbose (bool, optional): whether to write training and evaluation information to logger, default: True
            saved (bool, optional): whether to save the model parameters, default: True
            show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
            callback_fn (callable): Optional callback function executed at end of epoch.
                                    Includes (epoch_idx, valid_score) input arguments.

        Returns:
            (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
        """
        # Nothing left to train (e.g. resumed from a finished run): persist a
        # final checkpoint so downstream evaluation still has a model file.
        if saved and self.start_epoch >= self.epochs:
            self._save_checkpoint(-1, verbose=verbose)

        self.eval_collector.data_collect(train_data)
        if self.config["train_neg_sample_args"].get("dynamic", False):
            train_data.get_model(self.model)
        valid_step = 0

        for epoch_idx in range(self.start_epoch, self.epochs):
            # train
            training_start_time = time()
            # pass entire dataset as sparse csr, as required in https://benfred.github.io/implicit
            # NOTE(review): reaches into the protected ``_dataset`` attribute of the
            # dataloader — confirm no public accessor exists before changing.
            train_loss = self.model.calculate_loss(train_data._dataset.inter_matrix(form='csr'))
            self.train_loss_dict[epoch_idx] = (
                sum(train_loss) if isinstance(train_loss, tuple) else train_loss
            )
            training_end_time = time()
            train_loss_output = self._generate_train_loss_output(
                epoch_idx, training_start_time, training_end_time, train_loss
            )
            if verbose:
                self.logger.info(train_loss_output)
            self._add_train_loss_to_tensorboard(epoch_idx, train_loss)
            self.wandblogger.log_metrics(
                {"epoch": epoch_idx, "train_loss": train_loss, "train_step": epoch_idx},
                head="train",
            )

            # eval
            # No validation configured: checkpoint every epoch (if requested) and
            # skip the evaluation/early-stopping machinery entirely.
            if self.eval_step <= 0 or not valid_data:
                if saved:
                    self._save_checkpoint(epoch_idx, verbose=verbose)
                continue
            if (epoch_idx + 1) % self.eval_step == 0:
                valid_start_time = time()
                valid_score, valid_result = self._valid_epoch(
                    valid_data, show_progress=show_progress
                )

                # early_stopping updates the best score / patience counter and
                # tells us whether to stop (stop_flag) or save (update_flag).
                (
                    self.best_valid_score,
                    self.cur_step,
                    stop_flag,
                    update_flag,
                ) = early_stopping(
                    valid_score,
                    self.best_valid_score,
                    self.cur_step,
                    max_step=self.stopping_step,
                    bigger=self.valid_metric_bigger,
                )
                valid_end_time = time()
                valid_score_output = (
                    set_color("epoch %d evaluating", "green")
                    + " ["
                    + set_color("time", "blue")
                    + ": %.2fs, "
                    + set_color("valid_score", "blue")
                    + ": %f]"
                ) % (epoch_idx, valid_end_time - valid_start_time, valid_score)
                valid_result_output = (
                    set_color("valid result", "blue") + ": \n" + dict2str(valid_result)
                )
                if verbose:
                    self.logger.info(valid_score_output)
                    self.logger.info(valid_result_output)
                # "Vaild_score" is a typo, but it matches the tag used by the other
                # trainers in this file — kept for tensorboard consistency.
                self.tensorboard.add_scalar("Vaild_score", valid_score, epoch_idx)
                self.wandblogger.log_metrics(
                    {**valid_result, "valid_step": valid_step}, head="valid"
                )

                # Only checkpoint on improvement so the saved file always holds
                # the best model seen so far.
                if update_flag:
                    if saved:
                        self._save_checkpoint(epoch_idx, verbose=verbose)
                    self.best_valid_result = valid_result

                if callback_fn:
                    callback_fn(epoch_idx, valid_score)

                if stop_flag:
                    # cur_step counts eval rounds without improvement, so the best
                    # epoch is cur_step * eval_step epochs back from here.
                    stop_output = "Finished training, best eval result in epoch %d" % (
                        epoch_idx - self.cur_step * self.eval_step
                    )
                    if verbose:
                        self.logger.info(stop_output)
                    break

                valid_step += 1

        self._add_hparam_to_tensorboard(self.best_valid_score)
        return self.best_valid_score, self.best_valid_result
674 | 795 |
|
675 | 796 | class KGTrainer(Trainer): |
676 | 797 | r"""KGTrainer is designed for Knowledge-aware recommendation methods. Some of these models need to train the |
|
0 commit comments