Skip to content

Commit 0924b1b

Browse files
committed
Update ctr_trainer.py
1 parent ccbde16 commit 0924b1b

1 file changed

Lines changed: 3 additions & 1 deletion

File tree

torch_rechub/trainers/ctr_trainer.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,9 @@ class CTRTrainer(object):
2121
earlystop_patience (int): how long to wait after last time validation auc improved (default=10).
2222
device (str): `"cpu"` or `"cuda:0"`
2323
gpus (list): id of multi gpu (default=[]). If the length >=1, then the model will wrapped by nn.DataParallel.
24-
loss_mode (int, optional): the training mode, `{0:point-wise, 1:pair-wise, 2:list-wise}`. Defaults to 0.
24+
loss_mode (bool): whether the model returns prediction only or prediction with extra loss.
25+
``True`` means ``model(x_dict) -> y_pred``.
26+
``False`` means ``model(x_dict) -> (y_pred, other_loss)``.
2527
model_path (str): the path you want to save the model (default="./"). Note only save the best weight in the validation data.
2628
embedding_l1 (float): L1 regularization coefficient for embedding parameters (default=0.0).
2729
embedding_l2 (float): L2 regularization coefficient for embedding parameters (default=0.0).

0 commit comments

Comments
 (0)