Skip to content

Commit 3b996ee

Browse files
committed
Store training state and resume training from stored state. Fix bug in SFTDataset.__getitem__ (#103)
1 parent df7b2ef commit 3b996ee

14 files changed

Lines changed: 801 additions & 151 deletions

README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,21 @@ Basic arguments are:
319319
- `eval.micro_batch_size`: Local batch size to be bused for validation. Overrides
320320
`train.micro_batch_size`. This can often be larger, because evaluation needs
321321
less GPU memory than training.
322+
* `training_state_num`: Positive integer or `None`. If given, training states
323+
are stored alongside this number of last recently stored checkpoints. A
324+
stopped training run can be resumed from a checkpoint if the training state
325+
is available as well. See `resume`.
326+
* `resume`: If given, contains directory name for checkpoint from which
327+
training is to be resumed, such as "step-000100" or "final". You can only
328+
resume training from a checkpoint for which a training state has been stored
329+
as well, see `training_state_num`. Resuming a training run is not the same as
330+
starting training from a checkpoint, in that the following are all restored
331+
from the training state on top of the model weights:
332+
- Optimizer state
333+
- Learning rate scheduler state
334+
- Iteration number
335+
- Training/validation dataset split
336+
- Training iterator state
322337

323338
### Full Fine-tuning or LoRA
324339

keys_values/data/dataloader.py

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,9 +13,10 @@
1313
# limitations under the License.
1414
from typing import Iterator, Dict, Any, List, Callable
1515

16+
import torch
1617
from torch.utils.data import Dataset
1718

18-
from keys_values.data.iterators import BatchSampler
19+
from keys_values.data.iterators import BatchSampler, SimilarSequenceLengthIterator
1920

2021
Collator = Callable[[List[Dict[str, Any]]], Dict[str, Any]]
2122

@@ -39,6 +40,16 @@ def __next__(self) -> Dict[str, Any]:
3940
def __iter__(self) -> Iterator[Dict[str, Any]]:
4041
return self
4142

43+
def state_dict(self) -> Dict[str, torch.Tensor]:
44+
if not isinstance(self._batch_iter, SimilarSequenceLengthIterator):
45+
raise NotImplementedError("Only works for SimilarSequenceLengthIterator")
46+
return self._batch_iter.state_dict()
47+
48+
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
49+
if not isinstance(self._batch_iter, SimilarSequenceLengthIterator):
50+
raise NotImplementedError("Only works for SimilarSequenceLengthIterator")
51+
self._batch_iter.load_state_dict(state_dict)
52+
4253

4354
class MyDataLoader:
4455
def __init__(
@@ -60,6 +71,8 @@ def __init__(
6071
entries
6172
6273
"""
74+
self.dataset = dataset
75+
self.batch_sampler = batch_sampler
6376
self._iter_kwargs = {
6477
"dataset": dataset,
6578
"batch_sampler": batch_sampler,

keys_values/data/helmet.py

Lines changed: 57 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,8 @@
4040

4141
METADATA_FNAME = "helmet_metadata.json"
4242

43+
METADATA_TARGET_CHOICE_KEY = "target_choice"
44+
4345

4446
class Helmet(SequenceLengthFilteredDataModule):
4547
"""Data module for HELMET benchmark datasets.
@@ -119,10 +121,16 @@ def __init__(
119121
self.max_length = max_length
120122
self.dataset_parent_dir = dataset_parent_dir
121123
self.metadata_dir = metadata_dir
124+
self.target_choices = [None, None, None]
125+
self._metadata = None
122126

123-
def _metadata_keys(self, split: str) -> List[str]:
127+
def _metadata_keys(
128+
self,
129+
root_key: str,
130+
split: str,
131+
) -> List[str]:
124132
return [
125-
METADATA_SEQ_LENGTHS_KEY,
133+
root_key,
126134
self.dataset_key,
127135
self.max_length,
128136
self.tokenizer.model_name,
@@ -137,6 +145,12 @@ def _get_dataset(self) -> Tuple[RawDatasetType, Optional[RawDatasetType]]:
137145
)
138146
print(f"\nTransforming HELMET '{self.dataset_key}' ({self.max_length}) ...")
139147
metadata = self._load_metadata()
148+
self._metadata = metadata # Needed in :meth:`_create_datasets`
149+
self.target_choices = [
150+
self._get_target_choice(metadata, "train"),
151+
self._get_target_choice(metadata, "val"),
152+
self._get_target_choice(metadata, "test"),
153+
]
140154
train_data, dev_seq_lengths, dev_needs_store = self._transform(
141155
dev_data, split="dev", seq_lengths=self._get_seq_lengths(metadata, "dev")
142156
)
@@ -147,9 +161,17 @@ def _get_dataset(self) -> Tuple[RawDatasetType, Optional[RawDatasetType]]:
147161
if metadata is None:
148162
metadata = dict()
149163
if dev_needs_store:
150-
set_dict(metadata, self._metadata_keys("dev"), dev_seq_lengths)
164+
set_dict(
165+
metadata,
166+
self._metadata_keys(METADATA_SEQ_LENGTHS_KEY, "dev"),
167+
dev_seq_lengths,
168+
)
151169
if eval_needs_store:
152-
set_dict(metadata, self._metadata_keys("eval"), eval_seq_lengths)
170+
set_dict(
171+
metadata,
172+
self._metadata_keys(METADATA_SEQ_LENGTHS_KEY, "eval"),
173+
eval_seq_lengths,
174+
)
153175
self._store_metadata(metadata)
154176
return train_data, test_data
155177

@@ -221,7 +243,14 @@ def _transform(
221243
def _get_seq_lengths(
222244
self, metadata: Optional[Dict[str, Any]], split: str
223245
) -> Optional[List[int]]:
224-
return get_dict(metadata, self._metadata_keys(split))
246+
return get_dict(metadata, self._metadata_keys(METADATA_SEQ_LENGTHS_KEY, split))
247+
248+
def _get_target_choice(
249+
self, metadata: Optional[Dict[str, Any]], split: str
250+
) -> Optional[List[int]]:
251+
return get_dict(
252+
metadata, self._metadata_keys(METADATA_TARGET_CHOICE_KEY, split)
253+
)
225254

226255
def _load_metadata(self) -> Optional[Dict[str, Any]]:
227256
if self.metadata_dir is None:
@@ -253,25 +282,48 @@ def _create_datasets(
253282
val_kwargs: Dict[str, Any],
254283
test_kwargs: Optional[Dict[str, Any]],
255284
) -> None:
285+
num_sets = 2
256286
self.train_dataset = SFTDataset(
257287
**train_kwargs,
258288
mask_prompt=self.mask_prompt,
259289
ignore_index=self.ignore_index,
290+
target_choice=self.target_choices[0],
260291
seed=self.seed,
261292
)
262293
self.val_dataset = SFTDataset(
263294
**val_kwargs,
264295
mask_prompt=self.mask_prompt,
265296
ignore_index=self.ignore_index,
297+
target_choice=self.target_choices[1],
266298
seed=self.seed,
267299
)
268300
if test_kwargs is not None:
269301
self.test_dataset = SFTDataset(
270302
**test_kwargs,
271303
mask_prompt=self.mask_prompt,
272304
ignore_index=self.ignore_index,
305+
target_choice=self.target_choices[2],
273306
seed=self.seed,
274307
)
308+
num_sets += 1
309+
# Update meta-data?
310+
do_store_meta = any(x is None for x in self.target_choices[:num_sets])
311+
if do_store_meta:
312+
for i, (data, split) in enumerate(
313+
zip(
314+
(self.train_dataset, self.val_dataset, self.test_dataset),
315+
("train", "val", "test"),
316+
)
317+
):
318+
if self.target_choices[i] is None and data is not None:
319+
new_choices = data.target_choice.copy()
320+
self.target_choices[i] = new_choices
321+
set_dict(
322+
self._metadata,
323+
self._metadata_keys(METADATA_TARGET_CHOICE_KEY, split),
324+
new_choices,
325+
)
326+
self._store_metadata(self._metadata)
275327

276328
def _get_collate_fn(self) -> MyDataLoader:
277329
return get_sft_collate_fn(ignore_index=self.ignore_index)

keys_values/data/iterators.py

Lines changed: 82 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import math
15-
from typing import Iterator, List, Optional
15+
from typing import Iterator, List, Optional, Dict
1616

1717
import torch
1818
from torch.utils.data import Sampler
@@ -24,6 +24,18 @@ def batch_size(self) -> int:
2424
raise NotImplementedError
2525

2626

27+
FINGERPRINT_NAMES = (
28+
"pos",
29+
"dataset_size",
30+
"num_next",
31+
"micro_batch_size",
32+
"num_devices",
33+
"rank",
34+
)
35+
36+
_FINGERPRINT_NAME_TO_POS = dict(zip(FINGERPRINT_NAMES, range(len(FINGERPRINT_NAMES))))
37+
38+
2739
class SimilarSequenceLengthIterator(Iterator[List[int]]):
2840
def __init__(
2941
self,
@@ -155,6 +167,75 @@ def __next__(self) -> List[int]:
155167
def __iter__(self) -> Iterator[List[int]]:
156168
return self
157169

170+
def _fingerprint(self) -> List[int]:
171+
return [
172+
self._pos,
173+
self.dataset_size,
174+
self.num_next,
175+
self.micro_batch_size,
176+
self.num_devices,
177+
self.rank,
178+
]
179+
180+
def _check_fingerprint(self, fp: List[int]) -> int:
181+
if len(fp) != 6:
182+
raise ValueError(f"fp = {fp}: Fingerprint has 6 entries")
183+
assert _FINGERPRINT_NAME_TO_POS["pos"] == 0
184+
fp_curr = self._fingerprint()
185+
for name, elem, elem_curr in zip(FINGERPRINT_NAMES[1:], fp[1:], fp_curr[1:]):
186+
if elem != elem_curr:
187+
raise ValueError(
188+
f"Entry {name} of fingerprint: {elem}, but must be {elem_curr}"
189+
)
190+
return fp[0]
191+
192+
def _encode_partition(self) -> List[int]:
193+
partition_and_lengths = [[len(part)] + part for part in self._partition]
194+
return [x for part in partition_and_lengths for x in part]
195+
196+
def _decode_partition(self, encoded: List[int]):
197+
pos = 0
198+
enc_len = len(encoded)
199+
decoded = []
200+
while pos < enc_len:
201+
sz_part = encoded[pos]
202+
if not (0 < sz_part <= enc_len - pos - 1):
203+
raise ValueError(
204+
"Invalid size entry in encoded partition: "
205+
f"pos = {pos}, sz_part = {sz_part}:\n{encoded}"
206+
)
207+
pos += 1
208+
decoded.append(encoded[pos : (pos + sz_part)])
209+
pos += sz_part
210+
self._partition = decoded
211+
212+
def state_dict(self) -> Dict[str, torch.Tensor]:
213+
kwargs = dict(dtype=torch.int64)
214+
return {
215+
"fingerprint": torch.tensor(self._fingerprint(), **kwargs),
216+
"permutation": torch.tensor(self._permutation, **kwargs),
217+
"partition": torch.tensor(self._encode_partition(), **kwargs),
218+
}
219+
220+
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
221+
for name in ("fingerprint", "permutation", "partition"):
222+
if name not in state_dict:
223+
raise ValueError(f"State dict has no key {name}")
224+
pos = self._check_fingerprint(state_dict["fingerprint"].tolist())
225+
self._decode_partition(state_dict["partition"].tolist())
226+
self._permutation = state_dict["permutation"].tolist()
227+
self._pos = pos
228+
229+
@staticmethod
230+
def rank_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> int:
231+
return state_dict["fingerprint"].tolist()[_FINGERPRINT_NAME_TO_POS["rank"]]
232+
233+
@staticmethod
234+
def num_devices_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> int:
235+
return state_dict["fingerprint"].tolist()[
236+
_FINGERPRINT_NAME_TO_POS["num_devices"]
237+
]
238+
158239

159240
class SimilarSequenceLengthSampler(BatchSampler):
160241
"""

keys_values/data/load_helmet_dev_eval.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -208,24 +208,24 @@ def load_helmet_dev_eval(
208208
209209
Returns:
210210
A tuple of (dev_data, eval_data) datasets. Each data instance will contain at least "input", "output", "query_id", "max_length" fields.
211+
211212
"""
212213
dataset_parent_dir = Path(
213214
dataset_parent_dir
214215
).expanduser() # to ensure ~ can also exist in the given path
215216
source_data_dir = dataset_parent_dir.parent
216-
# 1) If the source data does not exisit, download it first
217+
# 1) If the source data does not exist, download it first
217218
if not os.path.exists(source_data_dir):
218219
download_source_data(source_data_dir)
219220

220-
cache_dir = os.path.join(
221-
dataset_parent_dir.parent, f"longtrain/{dataset_key}_{max_length}"
222-
)
221+
cache_dir = dataset_parent_dir.parent / "longtrain" / f"{dataset_key}_{max_length}"
223222

224223
# 2) If cached, load and return
225224
if os.path.isdir(cache_dir):
226225
dsd = load_from_disk(cache_dir) # expects a DatasetDict with dev/val
227226
# Support either naming convention if you ever change it
228227
if "development" in dsd and "evaluation" in dsd:
228+
print(f"Loaded cached datasets (development, evaluation) from {cache_dir}")
229229
return dsd["development"], dsd["evaluation"]
230230
raise ValueError(
231231
f"Cache found at {cache_dir}, but it doesn't contain expected splits."

0 commit comments

Comments
 (0)