diff --git a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py index 3093c6923..8d0d2e0d9 100644 --- a/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py +++ b/applications/DeepSpeed-Chat/dschat/utils/data/data_utils.py @@ -375,7 +375,9 @@ def create_prompt_dataset(local_rank, torch.save(train_dataset, train_fname) torch.save(eval_dataset, eval_fname) torch.distributed.barrier() - return torch.load(train_fname), torch.load(eval_fname) + return torch.load(train_fname, + weights_only=False), torch.load(eval_fname, + weights_only=False) class DataCollatorReward: