During training and evaluation, the output is a dictionary with three elements: {"loss": torch.Tensor, "labels": torch.Tensor, "predictions": torch.Tensor}
During inference, the output is the tensor of predictions.
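Because the return type depends on the mode, calling code may need to branch on it. The following is a minimal sketch that assumes only the two output shapes described above. `extract_predictions` is a hypothetical helper, not part of Transformers4Rec, and plain lists stand in for tensors so the example runs without PyTorch.

```python
def extract_predictions(output):
    """Return the predictions from a model output in either mode.

    Hypothetical helper for illustration; not part of Transformers4Rec.
    """
    if isinstance(output, dict):
        # Training/evaluation: {"loss": ..., "labels": ..., "predictions": ...}
        return output["predictions"]
    # Inference: the output is already the predictions tensor
    return output


# Plain lists stand in for torch tensors in this sketch.
train_output = {"loss": 0.42, "labels": [1, 0], "predictions": [0.9, 0.2]}
infer_output = [0.9, 0.2]

train_preds = extract_predictions(train_output)
infer_preds = extract_predictions(infer_output)
```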
2. Extend the Trainer class to support all prediction tasks:
The Trainer class now accepts a T4Rec model defined with binary or regression tasks.
Remove the HFWrapper class, as the Trainer now supports the base T4Rec Model class.
Set the default of the trainer's argument predict_top_k to 0 instead of 10.
Note that getting the top-k predictions is specific to NextItemPredictionTask, and the user should explicitly set predict_top_k in the T4RecTrainingArguments object. If it is not specified, the method Trainer.predict() returns unsorted predictions for the whole item catalog.
Support multi-task learning in the Trainer class: it accepts any T4Rec model defined with multiple tasks and/or multiple heads.
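To illustrate the difference between unsorted catalog scores and top-k predictions, here is a small sketch in plain Python. `top_k_items` is a hypothetical function written for this note, not the Trainer's implementation, and it assumes `scores[i]` is the score of item id `i`. The `k <= 0` branch only mimics the idea behind the new default of 0 (no ranking requested).

```python
import heapq


def top_k_items(scores, k):
    """Return (item_ids, top_scores) for the k highest-scoring items.

    Hypothetical helper; scores[i] is assumed to be the score of item id i.
    """
    if k <= 0:
        # No ranking requested: return scores for the whole catalog, unsorted
        return list(range(len(scores))), list(scores)
    # nlargest returns (item_id, score) pairs ordered from best to worst
    ranked = heapq.nlargest(k, enumerate(scores), key=lambda pair: pair[1])
    item_ids = [item_id for item_id, _ in ranked]
    top_scores = [score for _, score in ranked]
    return item_ids, top_scores


scores = [0.1, 0.7, 0.05, 0.9, 0.3]  # one score per catalog item
item_ids, top_scores = top_k_items(scores, k=3)
all_ids, all_scores = top_k_items(scores, k=0)
```

With `k=3` the sketch keeps only the three best items in ranked order, while `k=0` leaves the full score vector untouched for downstream code to post-process.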
3. Fix the inference performance of the Transformer-based model trained with masked language modeling (MLM):
At inference, the input sequence is extended by a [MASK] embedding after the last non-padded position to take into account the target position. The hidden representation of the [MASK] position is used to get the next-item prediction scores.
With this fix, the user no longer needs to add a dummy position to the input test data when calling Trainer.predict() or model(test_batch, training=False, testing=False).
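The idea of the fix can be sketched on item ids rather than embeddings. This is an illustration only: the library applies the [MASK] embedding inside the model, whereas the hypothetical `append_mask` below works on a right-padded list of ids. `PAD_ID` and `MASK_ID` are assumed placeholder values chosen for the example.

```python
PAD_ID = 0    # assumed padding id
MASK_ID = -1  # hypothetical placeholder standing in for the [MASK] position


def append_mask(sequence, pad_id=PAD_ID, mask_id=MASK_ID):
    """Extend a right-padded sequence by one slot and place the mask
    right after the last non-padded item.

    Returns the extended sequence and the index of the masked position.
    """
    n_items = sum(1 for item in sequence if item != pad_id)
    # One extra slot so that a sequence without padding still has room
    extended = list(sequence) + [pad_id]
    extended[n_items] = mask_id
    return extended, n_items


padded = [12, 7, 33, 0, 0]
extended, target_pos = append_mask(padded)

full = [12, 7, 33]
extended_full, target_pos_full = append_mask(full)
```

The hidden state at `target_pos` is the one that would be used to score the next item, which is why the caller no longer has to add that position manually.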
4. Update Transformers4Rec to use the new merlin-dataloader package: #547
The NVTabularDataLoader is renamed to MerlinDataLoader and uses the loader from the merlin-dataloader package.
Users can specify the argument data_loader_engine='merlin' in the T4RecTrainingArguments object to use the Merlin dataloader. It supports GPU and CPU environments. The alias nvtabular is also kept to ensure backward compatibility.
What’s Changed
⚠ Breaking Changes
Extend trainer class to support all T4Rec prediction tasks @sararb (#564)