Skip to content

Commit f7fa2b4

Browse files
committed
Clean up data part for training state, make it generic. Fixes a bug as well (#108)
1 parent 3168e3f commit f7fa2b4

11 files changed

Lines changed: 456 additions & 172 deletions

File tree

keys_values/data/helmet.py

Lines changed: 134 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from typing import List, Optional, Dict, Any, Tuple, Literal
1717

1818
from tokenizers import Tokenizer as HFTokenizer
19+
import torch
1920
from tqdm import tqdm
2021

2122
from keys_values.data.dataloader import MyDataLoader
@@ -26,6 +27,7 @@
2627
)
2728
from keys_values.data.module import (
2829
SequenceLengthFilteredDataModule,
30+
SequenceLengthFilteredDataTrainState,
2931
METADATA_SEQ_LENGTHS_KEY,
3032
METADATA_KEYS,
3133
RawDatasetType,
@@ -40,7 +42,86 @@
4042

4143
METADATA_FNAME = "helmet_metadata.json"
4244

43-
METADATA_TARGET_CHOICE_KEY = "target_choice"
45+
46+
class HelmetDataTrainState(SequenceLengthFilteredDataTrainState):
47+
"""
48+
Also contains the `target_choice` indexes for training, validation and
49+
test split.
50+
"""
51+
52+
def __init__(self):
53+
super().__init__()
54+
self._train_target_choice = None
55+
self._val_target_choice = None
56+
self._test_target_choice = None
57+
58+
@property
59+
def train_target_choice(self) -> Optional[List[int]]:
60+
return self._train_target_choice
61+
62+
@train_target_choice.setter
63+
def train_target_choice(self, value: Optional[List[int]]) -> None:
64+
# `>` is OK, as dataset may be padded after splitting
65+
if len(value) < len(self.train_data_index):
66+
raise ValueError(
67+
f"len(train_target_choice) = {len(value)} < {len(self.train_data_index)} = len(self.train_data_index)"
68+
)
69+
if not all(x >= 0 for x in value):
70+
raise ValueError("All entries of train_target_choice must be >= 0")
71+
self._train_target_choice = value.copy()
72+
73+
@property
74+
def val_target_choice(self) -> Optional[List[int]]:
75+
return self._val_target_choice
76+
77+
@val_target_choice.setter
78+
def val_target_choice(self, value: Optional[List[int]]) -> None:
79+
# `>` is OK, as dataset may be padded after splitting
80+
if len(value) < len(self.val_data_index):
81+
raise ValueError(
82+
f"len(val_target_choice) = {len(value)} < {len(self.val_data_index)} = len(self.val_data_index)"
83+
)
84+
if not all(x >= 0 for x in value):
85+
raise ValueError("All entries of val_target_choice must be >= 0")
86+
self._val_target_choice = value.copy()
87+
88+
@property
89+
def test_target_choice(self) -> Optional[List[int]]:
90+
return self._test_target_choice
91+
92+
@test_target_choice.setter
93+
def test_target_choice(self, value: Optional[List[int]]) -> None:
94+
if not all(x >= 0 for x in value):
95+
raise ValueError("All entries of test_target_choice must be >= 0")
96+
self._test_target_choice = value.copy()
97+
98+
def state_dict(self) -> Dict[str, torch.Tensor]:
99+
kwargs = dict(dtype=torch.int64)
100+
result = super().state_dict()
101+
result.update(
102+
{
103+
f"{name}_target_choice": torch.tensor(value, **kwargs)
104+
for name, value in zip(
105+
("train", "val", "test"),
106+
(
107+
self.train_target_choice,
108+
self.val_target_choice,
109+
self.test_target_choice,
110+
),
111+
)
112+
if value is not None
113+
}
114+
)
115+
return result
116+
117+
def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
118+
super().load_state_dict(state_dict)
119+
train_ind = state_dict.get("train_target_choice")
120+
val_ind = state_dict.get("val_target_choice")
121+
test_ind = state_dict.get("test_target_choice")
122+
self.train_target_choice = None if train_ind is None else train_ind.tolist()
123+
self.val_target_choice = None if val_ind is None else val_ind.tolist()
124+
self.test_target_choice = None if test_ind is None else test_ind.tolist()
44125

45126

46127
class Helmet(SequenceLengthFilteredDataModule):
@@ -121,8 +202,6 @@ def __init__(
121202
self.max_length = max_length
122203
self.dataset_parent_dir = dataset_parent_dir
123204
self.metadata_dir = metadata_dir
124-
self.target_choices = [None, None, None]
125-
self._metadata = None
126205

127206
def _metadata_keys(
128207
self,
@@ -145,12 +224,6 @@ def _get_dataset(self) -> Tuple[RawDatasetType, Optional[RawDatasetType]]:
145224
)
146225
print(f"\nTransforming HELMET '{self.dataset_key}' ({self.max_length}) ...")
147226
metadata = self._load_metadata()
148-
self._metadata = metadata # Needed in :meth:`_create_datasets`
149-
self.target_choices = [
150-
self._get_target_choice(metadata, "train"),
151-
self._get_target_choice(metadata, "val"),
152-
self._get_target_choice(metadata, "test"),
153-
]
154227
train_data, dev_seq_lengths, dev_needs_store = self._transform(
155228
dev_data, split="dev", seq_lengths=self._get_seq_lengths(metadata, "dev")
156229
)
@@ -245,13 +318,6 @@ def _get_seq_lengths(
245318
) -> Optional[List[int]]:
246319
return get_dict(metadata, self._metadata_keys(METADATA_SEQ_LENGTHS_KEY, split))
247320

248-
def _get_target_choice(
249-
self, metadata: Optional[Dict[str, Any]], split: str
250-
) -> Optional[List[int]]:
251-
return get_dict(
252-
metadata, self._metadata_keys(METADATA_TARGET_CHOICE_KEY, split)
253-
)
254-
255321
def _load_metadata(self) -> Optional[Dict[str, Any]]:
256322
if self.metadata_dir is None:
257323
return None
@@ -282,48 +348,73 @@ def _create_datasets(
282348
val_kwargs: Dict[str, Any],
283349
test_kwargs: Optional[Dict[str, Any]],
284350
) -> None:
285-
num_sets = 2
351+
assert self.training_state is not None # Sanity check
352+
if not isinstance(self.training_state, HelmetDataTrainState):
353+
# Must have been created in :meth:`SequenceLengthFilteredDataModule.setup`
354+
if not isinstance(
355+
self.training_state, SequenceLengthFilteredDataTrainState
356+
):
357+
raise TypeError(
358+
f"type(self.training_state) = {type(self.training_state)}: Invalid"
359+
)
360+
# Convert it
361+
new_training_state = HelmetDataTrainState()
362+
new_training_state.initialize(
363+
self.training_state.train_data_index,
364+
self.training_state.val_data_index,
365+
)
366+
self.training_state = new_training_state
367+
else:
368+
for name, value in zip(
369+
("train", "val", "test"),
370+
(
371+
self.training_state.train_target_choice,
372+
self.training_state.val_target_choice,
373+
self.training_state.test_target_choice,
374+
),
375+
):
376+
if value is not None:
377+
print(
378+
f"Loaded {name}_target_choice ({len(value)}) from training state"
379+
)
380+
target_choice = self.training_state.train_target_choice
286381
self.train_dataset = SFTDataset(
287382
**train_kwargs,
288383
mask_prompt=self.mask_prompt,
289384
ignore_index=self.ignore_index,
290-
target_choice=self.target_choices[0],
385+
target_choice=target_choice,
291386
seed=self.seed,
292387
)
388+
if target_choice is None:
389+
print(
390+
f"Sampled train_target_choice ({len(self.train_dataset.target_choice)})"
391+
)
392+
self.training_state.train_target_choice = self.train_dataset.target_choice
393+
target_choice = self.training_state.val_target_choice
293394
self.val_dataset = SFTDataset(
294395
**val_kwargs,
295396
mask_prompt=self.mask_prompt,
296397
ignore_index=self.ignore_index,
297-
target_choice=self.target_choices[1],
398+
target_choice=target_choice,
298399
seed=self.seed,
299400
)
401+
if target_choice is None:
402+
print(f"Sampled val_target_choice ({len(self.val_dataset.target_choice)})")
403+
self.training_state.val_target_choice = self.val_dataset.target_choice
300404
if test_kwargs is not None:
405+
target_choice = self.training_state.test_target_choice
301406
self.test_dataset = SFTDataset(
302407
**test_kwargs,
303408
mask_prompt=self.mask_prompt,
304409
ignore_index=self.ignore_index,
305-
target_choice=self.target_choices[2],
410+
target_choice=target_choice,
306411
seed=self.seed,
307412
)
308-
num_sets += 1
309-
# Update meta-data?
310-
do_store_meta = any(x is None for x in self.target_choices[:num_sets])
311-
if do_store_meta:
312-
for i, (data, split) in enumerate(
313-
zip(
314-
(self.train_dataset, self.val_dataset, self.test_dataset),
315-
("train", "val", "test"),
413+
if target_choice is None:
414+
print(
415+
f"Sampled test_target_choice ({len(self.test_dataset.target_choice)})"
316416
)
317-
):
318-
if self.target_choices[i] is None and data is not None:
319-
new_choices = data.target_choice.copy()
320-
self.target_choices[i] = new_choices
321-
set_dict(
322-
self._metadata,
323-
self._metadata_keys(METADATA_TARGET_CHOICE_KEY, split),
324-
new_choices,
325-
)
326-
self._store_metadata(self._metadata)
417+
self.training_state.test_target_choice = self.test_dataset.target_choice
327418

328419
def _get_collate_fn(self) -> MyDataLoader:
329420
return get_sft_collate_fn(ignore_index=self.ignore_index)
@@ -397,3 +488,8 @@ def smart_lastrec_info(self, tokenizer: HFTokenizer) -> SmartInitialInformation:
397488
max_initial_fraction=max_initial_fraction,
398489
include_end_string=include_end_string,
399490
)
491+
492+
def load_training_state(self, state_dict: Dict[str, torch.Tensor]):
493+
if self.training_state is None:
494+
self.training_state = HelmetDataTrainState()
495+
self.training_state.load_state_dict(state_dict)

0 commit comments

Comments
 (0)