Skip to content

Commit 945c0bb

Browse files
committed
Fix: Data training state stores epoch as well
1 parent f74c39a commit 945c0bb

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

keys_values/finetune/longcontext_full.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1383,7 +1383,7 @@ def fit(
13831383
num_devices=devices,
13841384
)
13851385
print_message(
1386-
f"Resume training: Continue from iteration {state['iter_num']}",
1386+
f"Resume training: Continue from epoch {train_iterator.epoch}, iteration {state['iter_num']}",
13871387
fabric,
13881388
)
13891389
if training_state is not None:

keys_values/finetune/resume_state.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -144,11 +144,15 @@ def _extract_training_state(self) -> Dict[str, Any]:
144144
name: self.state[name].state_dict() for name in self._state_components
145145
}
146146
kwargs = dict(dtype=torch.int64)
147+
iter_state = {
148+
**get_iterator(self.train_iterator).state_dict(),
149+
"epoch": torch.tensor(self.train_iterator.epoch, **kwargs),
150+
}
147151
train_state.update(
148152
{
149153
"data_state": self.dataset.training_state.state_dict(),
150154
"iter_num": torch.tensor(self.state["iter_num"], **kwargs),
151-
"train_iterator": get_iterator(self.train_iterator).state_dict(),
155+
"train_iterator": iter_state,
152156
}
153157
)
154158
return train_state
@@ -245,9 +249,11 @@ def restore_from_training_state(
245249
elif name in train_state:
246250
raise ValueError(f"{name}: Contained in train_state, but not in state")
247251
# Reconstruct the training iterator
252+
iter_state = train_state["train_iterator"]
248253
inner_iter = get_iterator(train_iterator)
249-
inner_iter.load_state_dict(train_state["train_iterator"])
254+
inner_iter.load_state_dict(iter_state)
250255
train_iterator._iterator = inner_iter
256+
train_iterator.epoch = iter_state["epoch"].item()
251257

252258

253259
def restore_dataset_from_training_state(

0 commit comments

Comments
 (0)