Skip to content

Commit 7e8e0b6

Browse files
authored
Merge pull request #2144 from BishopLiu/master
remove redundant class
2 parents 9a6f63d + d34bcd5 commit 7e8e0b6

File tree

1 file changed

+0
-125
lines changed

1 file changed

+0
-125
lines changed

recbole/trainer/trainer.py

Lines changed: 0 additions & 125 deletions
Original file line numberDiff line numberDiff line change
@@ -672,131 +672,6 @@ def _spilt_predict(self, interaction, batch_size):
672672
return torch.cat(result_list, dim=0)
673673

674674

675-
class ALSTrainer(Trainer):
676-
r"""ALSTrainer is designed for the ALS model of the implicit library: https://benfred.github.io/implicit"""
677-
678-
def __init__(self, config, model):
679-
super(ALSTrainer, self).__init__(config, model)
680-
681-
def fit(
682-
self,
683-
train_data,
684-
valid_data=None,
685-
verbose=True,
686-
saved=True,
687-
show_progress=False,
688-
callback_fn=None,
689-
):
690-
r"""Train the model based on the train data and the valid data.
691-
692-
Args:
693-
train_data (DataLoader): the train data
694-
valid_data (DataLoader, optional): the valid data, default: None.
695-
If it's None, the early_stopping is invalid.
696-
verbose (bool, optional): whether to write training and evaluation information to logger, default: True
697-
saved (bool, optional): whether to save the model parameters, default: True
698-
show_progress (bool): Show the progress of training epoch and evaluate epoch. Defaults to ``False``.
699-
callback_fn (callable): Optional callback function executed at end of epoch.
700-
Includes (epoch_idx, valid_score) input arguments.
701-
702-
Returns:
703-
(float, dict): best valid score and best valid result. If valid_data is None, it returns (-1, None)
704-
"""
705-
if saved and self.start_epoch >= self.epochs:
706-
self._save_checkpoint(-1, verbose=verbose)
707-
708-
self.eval_collector.data_collect(train_data)
709-
if self.config["train_neg_sample_args"].get("dynamic", False):
710-
train_data.get_model(self.model)
711-
valid_step = 0
712-
713-
for epoch_idx in range(self.start_epoch, self.epochs):
714-
# train
715-
training_start_time = time()
716-
# pass entire dataset as sparse csr, as required in https://benfred.github.io/implicit
717-
train_loss = self.model.calculate_loss(
718-
train_data._dataset.inter_matrix(form="csr")
719-
)
720-
self.train_loss_dict[epoch_idx] = (
721-
sum(train_loss) if isinstance(train_loss, tuple) else train_loss
722-
)
723-
training_end_time = time()
724-
train_loss_output = self._generate_train_loss_output(
725-
epoch_idx, training_start_time, training_end_time, train_loss
726-
)
727-
if verbose:
728-
self.logger.info(train_loss_output)
729-
self._add_train_loss_to_tensorboard(epoch_idx, train_loss)
730-
self.wandblogger.log_metrics(
731-
{"epoch": epoch_idx, "train_loss": train_loss, "train_step": epoch_idx},
732-
head="train",
733-
)
734-
735-
# eval
736-
if self.eval_step <= 0 or not valid_data:
737-
if saved:
738-
self._save_checkpoint(epoch_idx, verbose=verbose)
739-
continue
740-
if (epoch_idx + 1) % self.eval_step == 0:
741-
valid_start_time = time()
742-
valid_score, valid_result = self._valid_epoch(
743-
valid_data, show_progress=show_progress
744-
)
745-
746-
(
747-
self.best_valid_score,
748-
self.cur_step,
749-
stop_flag,
750-
update_flag,
751-
) = early_stopping(
752-
valid_score,
753-
self.best_valid_score,
754-
self.cur_step,
755-
max_step=self.stopping_step,
756-
bigger=self.valid_metric_bigger,
757-
)
758-
valid_end_time = time()
759-
valid_score_output = (
760-
set_color("epoch %d evaluating", "green")
761-
+ " ["
762-
+ set_color("time", "blue")
763-
+ ": %.2fs, "
764-
+ set_color("valid_score", "blue")
765-
+ ": %f]"
766-
) % (epoch_idx, valid_end_time - valid_start_time, valid_score)
767-
valid_result_output = (
768-
set_color("valid result", "blue") + ": \n" + dict2str(valid_result)
769-
)
770-
if verbose:
771-
self.logger.info(valid_score_output)
772-
self.logger.info(valid_result_output)
773-
self.tensorboard.add_scalar("Vaild_score", valid_score, epoch_idx)
774-
self.wandblogger.log_metrics(
775-
{**valid_result, "valid_step": valid_step}, head="valid"
776-
)
777-
778-
if update_flag:
779-
if saved:
780-
self._save_checkpoint(epoch_idx, verbose=verbose)
781-
self.best_valid_result = valid_result
782-
783-
if callback_fn:
784-
callback_fn(epoch_idx, valid_score)
785-
786-
if stop_flag:
787-
stop_output = "Finished training, best eval result in epoch %d" % (
788-
epoch_idx - self.cur_step * self.eval_step
789-
)
790-
if verbose:
791-
self.logger.info(stop_output)
792-
break
793-
794-
valid_step += 1
795-
796-
self._add_hparam_to_tensorboard(self.best_valid_score)
797-
return self.best_valid_score, self.best_valid_result
798-
799-
800675
class KGTrainer(Trainer):
801676
r"""KGTrainer is designed for Knowledge-aware recommendation methods. Some of these models need to train the
802677
recommendation related task and knowledge related task alternately.

0 commit comments

Comments
 (0)