@@ -144,11 +144,15 @@ def _extract_training_state(self) -> Dict[str, Any]:
144144 name : self .state [name ].state_dict () for name in self ._state_components
145145 }
146146 kwargs = dict (dtype = torch .int64 )
147+ iter_state = {
148+ ** get_iterator (self .train_iterator ).state_dict (),
149+ "epoch" : torch .tensor (self .train_iterator .epoch , ** kwargs ),
150+ }
147151 train_state .update (
148152 {
149153 "data_state" : self .dataset .training_state .state_dict (),
150154 "iter_num" : torch .tensor (self .state ["iter_num" ], ** kwargs ),
151- "train_iterator" : get_iterator ( self . train_iterator ). state_dict () ,
155+ "train_iterator" : iter_state ,
152156 }
153157 )
154158 return train_state
@@ -245,9 +249,11 @@ def restore_from_training_state(
245249 elif name in train_state :
246250 raise ValueError (f"{ name } : Contained in train_state, but not in state" )
247251 # Reconstruct the training iterator
252+ iter_state = train_state ["train_iterator" ]
248253 inner_iter = get_iterator (train_iterator )
249- inner_iter .load_state_dict (train_state [ "train_iterator" ] )
254+ inner_iter .load_state_dict (iter_state )
250255 train_iterator ._iterator = inner_iter
256+ train_iterator .epoch = iter_state ["epoch" ].item ()
251257
252258
253259def restore_dataset_from_training_state (
0 commit comments