Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -319,6 +319,21 @@ Basic arguments are:
- `eval.micro_batch_size`: Local batch size to be bused for validation. Overrides
`train.micro_batch_size`. This can often be larger, because evaluation needs
less GPU memory than training.
* `training_state_num`: Positive integer or `None`. If given, training states
are stored alongside this number of last recently stored checkpoints. A
stopped training run can be resumed from a checkpoint if the training state
is available as well. See `resume`.
* `resume`: If given, contains directory name for checkpoint from which
training is to be resumed, such as "step-000100" or "final". You can only
resume training from a checkpoint for which a training state has been stored
as well, see `training_state_num`. Resuming a training run is not the same as
starting training from a checkpoint, in that the following are all restored
from the training state on top of the model weights:
- Optimizer state
- Learning rate scheduler state
- Iteration number
- Training/validation dataset split
- Training iterator state

### Full Fine-tuning or LoRA

Expand Down
15 changes: 14 additions & 1 deletion keys_values/data/dataloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,10 @@
# limitations under the License.
from typing import Iterator, Dict, Any, List, Callable

import torch
from torch.utils.data import Dataset

from keys_values.data.iterators import BatchSampler
from keys_values.data.iterators import BatchSampler, SimilarSequenceLengthIterator

Collator = Callable[[List[Dict[str, Any]]], Dict[str, Any]]

Expand All @@ -39,6 +40,16 @@ def __next__(self) -> Dict[str, Any]:
def __iter__(self) -> Iterator[Dict[str, Any]]:
return self

def state_dict(self) -> Dict[str, torch.Tensor]:
if not isinstance(self._batch_iter, SimilarSequenceLengthIterator):
raise NotImplementedError("Only works for SimilarSequenceLengthIterator")
return self._batch_iter.state_dict()

def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
if not isinstance(self._batch_iter, SimilarSequenceLengthIterator):
raise NotImplementedError("Only works for SimilarSequenceLengthIterator")
self._batch_iter.load_state_dict(state_dict)


class MyDataLoader:
def __init__(
Expand All @@ -60,6 +71,8 @@ def __init__(
entries

"""
self.dataset = dataset
self.batch_sampler = batch_sampler
self._iter_kwargs = {
"dataset": dataset,
"batch_sampler": batch_sampler,
Expand Down
62 changes: 57 additions & 5 deletions keys_values/data/helmet.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@

METADATA_FNAME = "helmet_metadata.json"

METADATA_TARGET_CHOICE_KEY = "target_choice"


class Helmet(SequenceLengthFilteredDataModule):
"""Data module for HELMET benchmark datasets.
Expand Down Expand Up @@ -119,10 +121,16 @@ def __init__(
self.max_length = max_length
self.dataset_parent_dir = dataset_parent_dir
self.metadata_dir = metadata_dir
self.target_choices = [None, None, None]
self._metadata = None

def _metadata_keys(self, split: str) -> List[str]:
def _metadata_keys(
self,
root_key: str,
split: str,
) -> List[str]:
return [
METADATA_SEQ_LENGTHS_KEY,
root_key,
self.dataset_key,
self.max_length,
self.tokenizer.model_name,
Expand All @@ -137,6 +145,12 @@ def _get_dataset(self) -> Tuple[RawDatasetType, Optional[RawDatasetType]]:
)
print(f"\nTransforming HELMET '{self.dataset_key}' ({self.max_length}) ...")
metadata = self._load_metadata()
self._metadata = metadata # Needed in :meth:`_create_datasets`
self.target_choices = [
self._get_target_choice(metadata, "train"),
self._get_target_choice(metadata, "val"),
self._get_target_choice(metadata, "test"),
]
train_data, dev_seq_lengths, dev_needs_store = self._transform(
dev_data, split="dev", seq_lengths=self._get_seq_lengths(metadata, "dev")
)
Expand All @@ -147,9 +161,17 @@ def _get_dataset(self) -> Tuple[RawDatasetType, Optional[RawDatasetType]]:
if metadata is None:
metadata = dict()
if dev_needs_store:
set_dict(metadata, self._metadata_keys("dev"), dev_seq_lengths)
set_dict(
metadata,
self._metadata_keys(METADATA_SEQ_LENGTHS_KEY, "dev"),
dev_seq_lengths,
)
if eval_needs_store:
set_dict(metadata, self._metadata_keys("eval"), eval_seq_lengths)
set_dict(
metadata,
self._metadata_keys(METADATA_SEQ_LENGTHS_KEY, "eval"),
eval_seq_lengths,
)
self._store_metadata(metadata)
return train_data, test_data

Expand Down Expand Up @@ -221,7 +243,14 @@ def _transform(
def _get_seq_lengths(
self, metadata: Optional[Dict[str, Any]], split: str
) -> Optional[List[int]]:
return get_dict(metadata, self._metadata_keys(split))
return get_dict(metadata, self._metadata_keys(METADATA_SEQ_LENGTHS_KEY, split))

def _get_target_choice(
self, metadata: Optional[Dict[str, Any]], split: str
) -> Optional[List[int]]:
return get_dict(
metadata, self._metadata_keys(METADATA_TARGET_CHOICE_KEY, split)
)

def _load_metadata(self) -> Optional[Dict[str, Any]]:
if self.metadata_dir is None:
Expand Down Expand Up @@ -253,25 +282,48 @@ def _create_datasets(
val_kwargs: Dict[str, Any],
test_kwargs: Optional[Dict[str, Any]],
) -> None:
num_sets = 2
self.train_dataset = SFTDataset(
**train_kwargs,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
target_choice=self.target_choices[0],
seed=self.seed,
)
self.val_dataset = SFTDataset(
**val_kwargs,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
target_choice=self.target_choices[1],
seed=self.seed,
)
if test_kwargs is not None:
self.test_dataset = SFTDataset(
**test_kwargs,
mask_prompt=self.mask_prompt,
ignore_index=self.ignore_index,
target_choice=self.target_choices[2],
seed=self.seed,
)
num_sets += 1
# Update meta-data?
do_store_meta = any(x is None for x in self.target_choices[:num_sets])
if do_store_meta:
for i, (data, split) in enumerate(
zip(
(self.train_dataset, self.val_dataset, self.test_dataset),
("train", "val", "test"),
)
):
if self.target_choices[i] is None and data is not None:
new_choices = data.target_choice.copy()
self.target_choices[i] = new_choices
set_dict(
self._metadata,
self._metadata_keys(METADATA_TARGET_CHOICE_KEY, split),
new_choices,
)
self._store_metadata(self._metadata)

def _get_collate_fn(self) -> MyDataLoader:
return get_sft_collate_fn(ignore_index=self.ignore_index)
Expand Down
83 changes: 82 additions & 1 deletion keys_values/data/iterators.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from typing import Iterator, List, Optional
from typing import Iterator, List, Optional, Dict

import torch
from torch.utils.data import Sampler
Expand All @@ -24,6 +24,18 @@ def batch_size(self) -> int:
raise NotImplementedError


FINGERPRINT_NAMES = (
"pos",
"dataset_size",
"num_next",
"micro_batch_size",
"num_devices",
"rank",
)

_FINGERPRINT_NAME_TO_POS = dict(zip(FINGERPRINT_NAMES, range(len(FINGERPRINT_NAMES))))


class SimilarSequenceLengthIterator(Iterator[List[int]]):
def __init__(
self,
Expand Down Expand Up @@ -155,6 +167,75 @@ def __next__(self) -> List[int]:
def __iter__(self) -> Iterator[List[int]]:
return self

def _fingerprint(self) -> List[int]:
return [
self._pos,
self.dataset_size,
self.num_next,
self.micro_batch_size,
self.num_devices,
self.rank,
]

def _check_fingerprint(self, fp: List[int]) -> int:
if len(fp) != 6:
raise ValueError(f"fp = {fp}: Fingerprint has 6 entries")
assert _FINGERPRINT_NAME_TO_POS["pos"] == 0
fp_curr = self._fingerprint()
for name, elem, elem_curr in zip(FINGERPRINT_NAMES[1:], fp[1:], fp_curr[1:]):
if elem != elem_curr:
raise ValueError(
f"Entry {name} of fingerprint: {elem}, but must be {elem_curr}"
)
return fp[0]

def _encode_partition(self) -> List[int]:
partition_and_lengths = [[len(part)] + part for part in self._partition]
return [x for part in partition_and_lengths for x in part]

def _decode_partition(self, encoded: List[int]):
pos = 0
enc_len = len(encoded)
decoded = []
while pos < enc_len:
sz_part = encoded[pos]
if not (0 < sz_part <= enc_len - pos - 1):
raise ValueError(
"Invalid size entry in encoded partition: "
f"pos = {pos}, sz_part = {sz_part}:\n{encoded}"
)
pos += 1
decoded.append(encoded[pos : (pos + sz_part)])
pos += sz_part
self._partition = decoded

def state_dict(self) -> Dict[str, torch.Tensor]:
kwargs = dict(dtype=torch.int64)
return {
"fingerprint": torch.tensor(self._fingerprint(), **kwargs),
"permutation": torch.tensor(self._permutation, **kwargs),
"partition": torch.tensor(self._encode_partition(), **kwargs),
}

def load_state_dict(self, state_dict: Dict[str, torch.Tensor]):
for name in ("fingerprint", "permutation", "partition"):
if name not in state_dict:
raise ValueError(f"State dict has no key {name}")
pos = self._check_fingerprint(state_dict["fingerprint"].tolist())
self._decode_partition(state_dict["partition"].tolist())
self._permutation = state_dict["permutation"].tolist()
self._pos = pos

@staticmethod
def rank_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> int:
return state_dict["fingerprint"].tolist()[_FINGERPRINT_NAME_TO_POS["rank"]]

@staticmethod
def num_devices_from_state_dict(state_dict: Dict[str, torch.Tensor]) -> int:
return state_dict["fingerprint"].tolist()[
_FINGERPRINT_NAME_TO_POS["num_devices"]
]


class SimilarSequenceLengthSampler(BatchSampler):
"""
Expand Down
8 changes: 4 additions & 4 deletions keys_values/data/load_helmet_dev_eval.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,24 +208,24 @@ def load_helmet_dev_eval(

Returns:
A tuple of (dev_data, eval_data) datasets. Each data instance will contain at least "input", "output", "query_id", "max_length" fields.

"""
dataset_parent_dir = Path(
dataset_parent_dir
).expanduser() # to ensure ~ can also exist in the given path
source_data_dir = dataset_parent_dir.parent
# 1) If the source data does not exisit, download it first
# 1) If the source data does not exist, download it first
if not os.path.exists(source_data_dir):
download_source_data(source_data_dir)

cache_dir = os.path.join(
dataset_parent_dir.parent, f"longtrain/{dataset_key}_{max_length}"
)
cache_dir = dataset_parent_dir.parent / "longtrain" / f"{dataset_key}_{max_length}"

# 2) If cached, load and return
if os.path.isdir(cache_dir):
dsd = load_from_disk(cache_dir) # expects a DatasetDict with dev/val
# Support either naming convention if you ever change it
if "development" in dsd and "evaluation" in dsd:
print(f"Loaded cached datasets (development, evaluation) from {cache_dir}")
return dsd["development"], dsd["evaluation"]
raise ValueError(
f"Cache found at {cache_dir}, but it doesn't contain expected splits."
Expand Down
Loading
Loading