We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent 7c39a57 commit 96bb4f4Copy full SHA for 96bb4f4
recipes/full_finetune_distributed.py
@@ -811,7 +811,6 @@ def validate(self) -> float:
811
else float("inf")
812
)
813
814
- self._model.train()
815
return avg_val_loss
816
817
def train(self) -> None:
@@ -848,6 +847,7 @@ def train(self) -> None:
848
847
and self._device.type == "cuda"
849
):
850
torch.cuda.memory._record_memory_history()
+ self._model.train()
851
utils.batch_to_device(batch, self._device)
852
853
# Calculate the number of unmasked tokens in the current batch
0 commit comments