@@ -672,131 +672,6 @@ def _spilt_predict(self, interaction, batch_size):
672672 return torch .cat (result_list , dim = 0 )
673673
674674
class ALSTrainer(Trainer):
    r"""ALSTrainer is designed for the ALS model of the implicit library: https://benfred.github.io/implicit"""

    def __init__(self, config, model):
        super(ALSTrainer, self).__init__(config, model)

    def fit(
        self,
        train_data,
        valid_data=None,
        verbose=True,
        saved=True,
        show_progress=False,
        callback_fn=None,
    ):
        r"""Train the model based on the train data and the valid data.

        Args:
            train_data (DataLoader): the train data
            valid_data (DataLoader, optional): the valid data, default: None.
                If it's None, the early_stopping is invalid.
            verbose (bool, optional): whether to write training and evaluation information to logger, default: True
            saved (bool, optional): whether to save the model parameters, default: True
            show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
            callback_fn (callable): Optional callback function executed at end of epoch.
                Includes (epoch_idx, valid_score) input arguments.

        Returns:
            (float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
        """
        # Resuming at/past the final epoch: persist the loaded state once and
        # fall through (the epoch loop below will not run).
        if saved and self.start_epoch >= self.epochs:
            self._save_checkpoint(-1, verbose=verbose)

        self.eval_collector.data_collect(train_data)
        if self.config["train_neg_sample_args"].get("dynamic", False):
            train_data.get_model(self.model)

        valid_step = 0
        for epoch_idx in range(self.start_epoch, self.epochs):
            # ---- train -----------------------------------------------------
            train_start = time()
            # Unlike the batched Trainer, implicit consumes the whole dataset
            # as one sparse CSR matrix: https://benfred.github.io/implicit
            epoch_loss = self.model.calculate_loss(
                train_data._dataset.inter_matrix(form="csr")
            )
            self.train_loss_dict[epoch_idx] = (
                sum(epoch_loss) if isinstance(epoch_loss, tuple) else epoch_loss
            )
            train_end = time()
            loss_output = self._generate_train_loss_output(
                epoch_idx, train_start, train_end, epoch_loss
            )
            if verbose:
                self.logger.info(loss_output)
            self._add_train_loss_to_tensorboard(epoch_idx, epoch_loss)
            self.wandblogger.log_metrics(
                {"epoch": epoch_idx, "train_loss": epoch_loss, "train_step": epoch_idx},
                head="train",
            )

            # ---- eval ------------------------------------------------------
            if self.eval_step <= 0 or not valid_data:
                # No validation configured: checkpoint every epoch instead.
                if saved:
                    self._save_checkpoint(epoch_idx, verbose=verbose)
                continue
            if (epoch_idx + 1) % self.eval_step != 0:
                # Not an evaluation epoch (note: no checkpoint is written here,
                # matching the base Trainer's behavior).
                continue

            valid_start = time()
            valid_score, valid_result = self._valid_epoch(
                valid_data, show_progress=show_progress
            )

            (
                self.best_valid_score,
                self.cur_step,
                stop_flag,
                update_flag,
            ) = early_stopping(
                valid_score,
                self.best_valid_score,
                self.cur_step,
                max_step=self.stopping_step,
                bigger=self.valid_metric_bigger,
            )
            valid_end = time()

            score_output = (
                set_color("epoch %d evaluating", "green")
                + " ["
                + set_color("time", "blue")
                + ": %.2fs, "
                + set_color("valid_score", "blue")
                + ": %f]"
            ) % (epoch_idx, valid_end - valid_start, valid_score)
            result_output = (
                set_color("valid result", "blue") + ": \n " + dict2str(valid_result)
            )
            if verbose:
                self.logger.info(score_output)
                self.logger.info(result_output)
            # NOTE(review): "Vaild_score" is a typo, but it is the tag used by
            # the other trainers in this file and by existing tensorboard runs,
            # so it is kept as-is for consistency.
            self.tensorboard.add_scalar("Vaild_score", valid_score, epoch_idx)
            self.wandblogger.log_metrics(
                {**valid_result, "valid_step": valid_step}, head="valid"
            )

            if update_flag:
                if saved:
                    self._save_checkpoint(epoch_idx, verbose=verbose)
                self.best_valid_result = valid_result

            if callback_fn:
                callback_fn(epoch_idx, valid_score)

            if stop_flag:
                stop_output = "Finished training, best eval result in epoch %d" % (
                    epoch_idx - self.cur_step * self.eval_step
                )
                if verbose:
                    self.logger.info(stop_output)
                break

            valid_step += 1

        self._add_hparam_to_tensorboard(self.best_valid_score)
        return self.best_valid_score, self.best_valid_result
799-
800675class KGTrainer (Trainer ):
801676 r"""KGTrainer is designed for Knowledge-aware recommendation methods. Some of these models need to train the
802677 recommendation related task and knowledge related task alternately.
0 commit comments