Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
172 changes: 134 additions & 38 deletions keys_values/data/helmet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from typing import List, Optional, Dict, Any, Tuple, Literal

from tokenizers import Tokenizer as HFTokenizer
import torch
from tqdm import tqdm

from keys_values.data.dataloader import MyDataLoader
Expand All @@ -26,6 +27,7 @@
)
from keys_values.data.module import (
SequenceLengthFilteredDataModule,
SequenceLengthFilteredDataTrainState,
METADATA_SEQ_LENGTHS_KEY,
METADATA_KEYS,
RawDatasetType,
Expand All @@ -40,7 +42,86 @@

METADATA_FNAME = "helmet_metadata.json"

METADATA_TARGET_CHOICE_KEY = "target_choice"

class HelmetDataTrainState(SequenceLengthFilteredDataTrainState):
"""
Also contains the `target_choice` indexes for training, validation and
test split.
"""

def __init__(self):
super().__init__()
self._train_target_choice = None
self._val_target_choice = None
self._test_target_choice = None

@property
def train_target_choice(self) -> Optional[List[int]]:
return self._train_target_choice

@train_target_choice.setter
def train_target_choice(self, value: Optional[List[int]]) -> None:
# `>` is OK, as dataset may be padded after splitting
if len(value) < len(self.train_data_index):
raise ValueError(
f"len(train_target_choice) = {len(value)} < {len(self.train_data_index)} = len(self.train_data_index)"
)
if not all(x >= 0 for x in value):
raise ValueError("All entries of train_target_choice must be >= 0")
self._train_target_choice = value.copy()

@property
def val_target_choice(self) -> Optional[List[int]]:
return self._val_target_choice

@val_target_choice.setter
def val_target_choice(self, value: Optional[List[int]]) -> None:
# `>` is OK, as dataset may be padded after splitting
if len(value) < len(self.val_data_index):
raise ValueError(
f"len(val_target_choice) = {len(value)} < {len(self.val_data_index)} = len(self.val_data_index)"
)
if not all(x >= 0 for x in value):
raise ValueError("All entries of val_target_choice must be >= 0")
self._val_target_choice = value.copy()

@property
def test_target_choice(self) -> Optional[List[int]]:
return self._test_target_choice

@test_target_choice.setter
def test_target_choice(self, value: Optional[List[int]]) -> None:
if not all(x >= 0 for x in value):
raise ValueError("All entries of test_target_choice must be >= 0")
self._test_target_choice = value.copy()

def state_dict(self) -> Dict[str, torch.Tensor]:
kwargs = dict(dtype=torch.int64)
result = super().state_dict()
result.update(
{
f"{name}_target_choice": torch.tensor(value, **kwargs)
for name, value in zip(
("train", "val", "test"),
(
self.train_target_choice,
self.val_target_choice,
self.test_target_choice,
),
)
if value is not None
}
)
return result

def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
super().load_state_dict(state_dict)
train_ind = state_dict.get("train_target_choice")
val_ind = state_dict.get("val_target_choice")
test_ind = state_dict.get("test_target_choice")
self.train_target_choice = None if train_ind is None else train_ind.tolist()
self.val_target_choice = None if val_ind is None else val_ind.tolist()
self.test_target_choice = None if test_ind is None else test_ind.tolist()


class Helmet(SequenceLengthFilteredDataModule):
Expand Down Expand Up @@ -121,8 +202,6 @@ def __init__(
self.max_length = max_length
self.dataset_parent_dir = dataset_parent_dir
self.metadata_dir = metadata_dir
self.target_choices = [None, None, None]
self._metadata = None

def _metadata_keys(
self,
Expand All @@ -145,12 +224,6 @@ def _get_dataset(self) -> Tuple[RawDatasetType, Optional[RawDatasetType]]:
)
print(f"\nTransforming HELMET '{self.dataset_key}' ({self.max_length}) ...")
metadata = self._load_metadata()
self._metadata = metadata # Needed in :meth:`_create_datasets`
self.target_choices = [
self._get_target_choice(metadata, "train"),
self._get_target_choice(metadata, "val"),
self._get_target_choice(metadata, "test"),
]
train_data, dev_seq_lengths, dev_needs_store = self._transform(
dev_data, split="dev", seq_lengths=self._get_seq_lengths(metadata, "dev")
)
Expand Down Expand Up @@ -245,13 +318,6 @@ def _get_seq_lengths(
) -> Optional[List[int]]:
return get_dict(metadata, self._metadata_keys(METADATA_SEQ_LENGTHS_KEY, split))

def _get_target_choice(
self, metadata: Optional[Dict[str, Any]], split: str
) -> Optional[List[int]]:
return get_dict(
metadata, self._metadata_keys(METADATA_TARGET_CHOICE_KEY, split)
)

def _load_metadata(self) -> Optional[Dict[str, Any]]:
if self.metadata_dir is None:
return None
Expand Down Expand Up @@ -282,48 +348,73 @@ def _create_datasets(
val_kwargs: Dict[str, Any],
test_kwargs: Optional[Dict[str, Any]],
) -> None:
num_sets = 2
assert self.training_state is not None # Sanity check
if not isinstance(self.training_state, HelmetDataTrainState):
# Must have been created in :meth:`SequenceLengthFilteredDataModule.setup`
if not isinstance(
self.training_state, SequenceLengthFilteredDataTrainState
):
raise TypeError(
f"type(self.training_state) = {type(self.training_state)}: Invalid"
)
# Convert it
new_training_state = HelmetDataTrainState()
new_training_state.initialize(
self.training_state.train_data_index,
self.training_state.val_data_index,
)
self.training_state = new_training_state
else:
for name, value in zip(
("train", "val", "test"),
(
self.training_state.train_target_choice,
self.training_state.val_target_choice,
self.training_state.test_target_choice,
),
):
if value is not None:
print(
f"Loaded {name}_target_choice ({len(value)}) from training state"
)
target_choice = self.training_state.train_target_choice
self.train_dataset = SFTDataset(
**train_kwargs,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
target_choice=self.target_choices[0],
target_choice=target_choice,
seed=self.seed,
)
if target_choice is None:
print(
f"Sampled train_target_choice ({len(self.train_dataset.target_choice)})"
)
self.training_state.train_target_choice = self.train_dataset.target_choice
target_choice = self.training_state.val_target_choice
self.val_dataset = SFTDataset(
**val_kwargs,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
target_choice=self.target_choices[1],
target_choice=target_choice,
seed=self.seed,
)
if target_choice is None:
print(f"Sampled val_target_choice ({len(self.val_dataset.target_choice)})")
self.training_state.val_target_choice = self.val_dataset.target_choice
if test_kwargs is not None:
target_choice = self.training_state.test_target_choice
self.test_dataset = SFTDataset(
**test_kwargs,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
target_choice=self.target_choices[2],
target_choice=target_choice,
seed=self.seed,
)
num_sets += 1
# Update meta-data?
do_store_meta = any(x is None for x in self.target_choices[:num_sets])
if do_store_meta:
for i, (data, split) in enumerate(
zip(
(self.train_dataset, self.val_dataset, self.test_dataset),
("train", "val", "test"),
if target_choice is None:
print(
f"Sampled test_target_choice ({len(self.test_dataset.target_choice)})"
)
):
if self.target_choices[i] is None and data is not None:
new_choices = data.target_choice.copy()
self.target_choices[i] = new_choices
set_dict(
self._metadata,
self._metadata_keys(METADATA_TARGET_CHOICE_KEY, split),
new_choices,
)
self._store_metadata(self._metadata)
self.training_state.test_target_choice = self.test_dataset.target_choice

def _get_collate_fn(self) -> MyDataLoader:
return get_sft_collate_fn(ignore_index=self.ignore_index)
Expand Down Expand Up @@ -397,3 +488,8 @@ def smart_lastrec_info(self, tokenizer: HFTokenizer) -> SmartInitialInformation:
max_initial_fraction=max_initial_fraction,
include_end_string=include_end_string,
)

def load_training_state(self, state_dict: Dict[str, torch.Tensor]):
if self.training_state is None:
self.training_state = HelmetDataTrainState()
self.training_state.load_state_dict(state_dict)
Loading
Loading